@huggingface/transformers 3.0.0-alpha.9 → 3.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +82 -50
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +2550 -2552
- package/dist/transformers.cjs.map +1 -1
- package/dist/transformers.js +3639 -3567
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.cjs +25 -25
- package/dist/transformers.min.cjs.map +1 -1
- package/dist/transformers.min.js +41 -42
- package/dist/transformers.min.js.map +1 -1
- package/dist/transformers.min.mjs +56 -57
- package/dist/transformers.min.mjs.map +1 -1
- package/dist/transformers.mjs +2586 -2564
- package/dist/transformers.mjs.map +1 -1
- package/package.json +14 -13
- package/src/backends/onnx.js +24 -19
- package/src/configs.js +19 -4
- package/src/env.js +5 -9
- package/src/generation/logits_process.js +40 -37
- package/src/models.js +356 -539
- package/src/ops/registry.js +14 -3
- package/src/pipelines.js +5 -5
- package/src/processors.js +392 -351
- package/src/tokenizers.js +140 -175
- package/src/utils/constants.js +1 -1
- package/src/utils/core.js +12 -0
- package/src/utils/data-structures.js +13 -11
- package/src/utils/hub.js +1 -1
- package/src/utils/maths.js +14 -5
- package/src/utils/tensor.js +60 -13
- package/types/backends/onnx.d.ts +5 -2
- package/types/backends/onnx.d.ts.map +1 -1
- package/types/configs.d.ts +29 -3
- package/types/configs.d.ts.map +1 -1
- package/types/env.d.ts +4 -2
- package/types/env.d.ts.map +1 -1
- package/types/generation/logits_process.d.ts.map +1 -1
- package/types/models.d.ts +116 -289
- package/types/models.d.ts.map +1 -1
- package/types/ops/registry.d.ts +6 -6
- package/types/ops/registry.d.ts.map +1 -1
- package/types/pipelines.d.ts +1 -2
- package/types/pipelines.d.ts.map +1 -1
- package/types/processors.d.ts +58 -51
- package/types/processors.d.ts.map +1 -1
- package/types/tokenizers.d.ts +23 -32
- package/types/tokenizers.d.ts.map +1 -1
- package/types/utils/constants.d.ts +1 -1
- package/types/utils/constants.d.ts.map +1 -1
- package/types/utils/core.d.ts +7 -0
- package/types/utils/core.d.ts.map +1 -1
- package/types/utils/data-structures.d.ts +6 -6
- package/types/utils/data-structures.d.ts.map +1 -1
- package/types/utils/hub.d.ts +1 -1
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/maths.d.ts +2 -2
- package/types/utils/maths.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +27 -1
- package/types/utils/tensor.d.ts.map +1 -1
package/src/models.js
CHANGED
|
@@ -71,6 +71,10 @@ import {
|
|
|
71
71
|
getModelJSON,
|
|
72
72
|
} from './utils/hub.js';
|
|
73
73
|
|
|
74
|
+
import {
|
|
75
|
+
GITHUB_ISSUE_URL,
|
|
76
|
+
} from './utils/constants.js';
|
|
77
|
+
|
|
74
78
|
import {
|
|
75
79
|
LogitsProcessorList,
|
|
76
80
|
ForcedBOSTokenLogitsProcessor,
|
|
@@ -142,11 +146,12 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
|
|
|
142
146
|
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
|
|
143
147
|
* @param {string} fileName The name of the model file.
|
|
144
148
|
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
|
|
145
|
-
* @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
|
|
149
|
+
* @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
|
|
146
150
|
* @private
|
|
147
151
|
*/
|
|
148
152
|
async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
149
|
-
|
|
153
|
+
const custom_config = options.config?.['transformers.js_config'] ?? {};
|
|
154
|
+
let device = options.device ?? custom_config.device;
|
|
150
155
|
if (device && typeof device !== 'string') {
|
|
151
156
|
if (device.hasOwnProperty(fileName)) {
|
|
152
157
|
device = device[fileName];
|
|
@@ -164,7 +169,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
164
169
|
|
|
165
170
|
// If options.dtype is specified, we use it to choose the suffix for the model file.
|
|
166
171
|
// Otherwise, we use the default dtype for the device.
|
|
167
|
-
let dtype = options.dtype;
|
|
172
|
+
let dtype = options.dtype ?? custom_config.dtype;
|
|
168
173
|
if (typeof dtype !== 'string') {
|
|
169
174
|
if (dtype && dtype.hasOwnProperty(fileName)) {
|
|
170
175
|
dtype = dtype[fileName];
|
|
@@ -182,27 +187,54 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
182
187
|
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
|
|
183
188
|
}
|
|
184
189
|
|
|
190
|
+
// Only valid for models with a decoder
|
|
191
|
+
const kv_cache_dtype = custom_config.kv_cache_dtype
|
|
192
|
+
? (typeof custom_config.kv_cache_dtype === 'string'
|
|
193
|
+
? custom_config.kv_cache_dtype
|
|
194
|
+
: custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
|
|
195
|
+
: undefined;
|
|
196
|
+
|
|
197
|
+
if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
|
|
198
|
+
throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
const session_config = {
|
|
202
|
+
dtype: selectedDtype,
|
|
203
|
+
kv_cache_dtype,
|
|
204
|
+
}
|
|
205
|
+
|
|
185
206
|
// Construct the model file name
|
|
186
207
|
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
|
|
187
208
|
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
|
|
188
209
|
|
|
189
|
-
const session_options = { ...options.session_options }
|
|
210
|
+
const session_options = { ...options.session_options };
|
|
190
211
|
|
|
191
212
|
// Overwrite `executionProviders` if not specified
|
|
192
213
|
session_options.executionProviders ??= executionProviders;
|
|
193
214
|
|
|
215
|
+
// Overwrite `freeDimensionOverrides` if specified in config and not set in session options
|
|
216
|
+
const free_dimension_overrides = custom_config.free_dimension_overrides;
|
|
217
|
+
if (free_dimension_overrides) {
|
|
218
|
+
session_options.freeDimensionOverrides ??= free_dimension_overrides;
|
|
219
|
+
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
|
|
220
|
+
console.warn(
|
|
221
|
+
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
|
|
222
|
+
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
|
|
223
|
+
);
|
|
224
|
+
}
|
|
194
225
|
|
|
195
226
|
const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
|
|
196
227
|
|
|
197
228
|
// handle onnx external data files
|
|
229
|
+
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
|
|
198
230
|
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
|
|
199
231
|
let externalDataPromises = [];
|
|
200
|
-
if (
|
|
201
|
-
|
|
232
|
+
if (use_external_data_format && (
|
|
233
|
+
use_external_data_format === true ||
|
|
202
234
|
(
|
|
203
|
-
typeof
|
|
204
|
-
|
|
205
|
-
|
|
235
|
+
typeof use_external_data_format === 'object' &&
|
|
236
|
+
use_external_data_format.hasOwnProperty(fileName) &&
|
|
237
|
+
use_external_data_format[fileName] === true
|
|
206
238
|
)
|
|
207
239
|
)) {
|
|
208
240
|
if (apis.IS_NODE_ENV) {
|
|
@@ -236,6 +268,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
236
268
|
});
|
|
237
269
|
if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
|
|
238
270
|
// Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
|
|
271
|
+
/** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
|
|
239
272
|
const preferredOutputLocation = {};
|
|
240
273
|
for (const key in shapes) {
|
|
241
274
|
preferredOutputLocation[key] = 'gpu-buffer';
|
|
@@ -245,7 +278,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
245
278
|
}
|
|
246
279
|
|
|
247
280
|
const buffer = await bufferPromise;
|
|
248
|
-
|
|
281
|
+
|
|
282
|
+
return { buffer, session_options, session_config };
|
|
249
283
|
}
|
|
250
284
|
|
|
251
285
|
/**
|
|
@@ -260,13 +294,30 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
260
294
|
async function constructSessions(pretrained_model_name_or_path, names, options) {
|
|
261
295
|
return Object.fromEntries(await Promise.all(
|
|
262
296
|
Object.keys(names).map(async (name) => {
|
|
263
|
-
const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
|
|
264
|
-
const session = await createInferenceSession(buffer, session_options);
|
|
297
|
+
const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
|
|
298
|
+
const session = await createInferenceSession(buffer, session_options, session_config);
|
|
265
299
|
return [name, session];
|
|
266
300
|
})
|
|
267
301
|
));
|
|
268
302
|
}
|
|
269
303
|
|
|
304
|
+
/**
|
|
305
|
+
* Helper function to load multiple optional configuration files
|
|
306
|
+
* @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
|
|
307
|
+
* @param {Record<string, string>} names The names of the config files to load.
|
|
308
|
+
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
|
|
309
|
+
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
|
|
310
|
+
* @private
|
|
311
|
+
*/
|
|
312
|
+
async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
|
|
313
|
+
return Object.fromEntries(await Promise.all(
|
|
314
|
+
Object.keys(names).map(async (name) => {
|
|
315
|
+
const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
|
|
316
|
+
return [name, config];
|
|
317
|
+
})
|
|
318
|
+
));
|
|
319
|
+
}
|
|
320
|
+
|
|
270
321
|
/**
|
|
271
322
|
* Validate model inputs
|
|
272
323
|
* @param {Object} session The InferenceSession object that will be run.
|
|
@@ -360,7 +411,7 @@ function replaceTensors(obj) {
|
|
|
360
411
|
|
|
361
412
|
/**
|
|
362
413
|
* Converts an array or Tensor of integers to an int64 Tensor.
|
|
363
|
-
* @param {
|
|
414
|
+
* @param {any[]|Tensor} items The input integers to be converted.
|
|
364
415
|
* @returns {Tensor} The int64 Tensor with the converted values.
|
|
365
416
|
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
|
|
366
417
|
* @private
|
|
@@ -393,37 +444,6 @@ function toI64Tensor(items) {
|
|
|
393
444
|
}
|
|
394
445
|
}
|
|
395
446
|
|
|
396
|
-
/**
|
|
397
|
-
* Prepares an attention mask for a sequence of tokens based on configuration options.
|
|
398
|
-
* @param {Object} self The calling object instance.
|
|
399
|
-
* @param {Tensor} tokens The input tokens.
|
|
400
|
-
* @returns {Tensor} The attention mask tensor.
|
|
401
|
-
* @private
|
|
402
|
-
*/
|
|
403
|
-
function prepareAttentionMask(self, tokens) {
|
|
404
|
-
|
|
405
|
-
// Prepare attention mask
|
|
406
|
-
let pad_token_id = self.config.pad_token_id ?? null;
|
|
407
|
-
let eos_token_id = self.config.eos_token_id ?? null;
|
|
408
|
-
if (isIntegralNumber(eos_token_id)) {
|
|
409
|
-
eos_token_id = [eos_token_id];
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
let is_pad_token_in_inputs = tokens.indexOf(pad_token_id) !== -1;
|
|
413
|
-
let is_pad_token_not_equal_to_eos_token_id = (eos_token_id === null) || !eos_token_id.includes(pad_token_id)
|
|
414
|
-
|
|
415
|
-
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
|
|
416
|
-
let data = BigInt64Array.from(
|
|
417
|
-
// Note: != so that int matches bigint
|
|
418
|
-
// @ts-ignore
|
|
419
|
-
tokens.data.map(x => x != pad_token_id)
|
|
420
|
-
)
|
|
421
|
-
return new Tensor('int64', data, tokens.dims)
|
|
422
|
-
} else {
|
|
423
|
-
return ones_like(tokens);
|
|
424
|
-
}
|
|
425
|
-
}
|
|
426
|
-
|
|
427
447
|
/**
|
|
428
448
|
* Creates a boolean tensor with a single value.
|
|
429
449
|
* @param {boolean} value The value of the tensor.
|
|
@@ -694,8 +714,8 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
|
|
|
694
714
|
} else {
|
|
695
715
|
return decoder_prepare_inputs_for_generation(self, ...args);
|
|
696
716
|
}
|
|
697
|
-
|
|
698
717
|
}
|
|
718
|
+
|
|
699
719
|
//////////////////////////////////////////////////
|
|
700
720
|
|
|
701
721
|
//////////////////////////////////////////////////
|
|
@@ -709,12 +729,14 @@ export class PreTrainedModel extends Callable {
|
|
|
709
729
|
* Creates a new instance of the `PreTrainedModel` class.
|
|
710
730
|
* @param {import('./configs.js').PretrainedConfig} config The model configuration.
|
|
711
731
|
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
732
|
+
* @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
|
|
712
733
|
*/
|
|
713
|
-
constructor(config, sessions) {
|
|
734
|
+
constructor(config, sessions, configs) {
|
|
714
735
|
super();
|
|
715
736
|
|
|
716
737
|
this.config = config;
|
|
717
738
|
this.sessions = sessions;
|
|
739
|
+
this.configs = configs;
|
|
718
740
|
|
|
719
741
|
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
|
|
720
742
|
const modelType = MODEL_TYPE_MAPPING.get(modelName);
|
|
@@ -830,7 +852,9 @@ export class PreTrainedModel extends Callable {
|
|
|
830
852
|
constructSessions(pretrained_model_name_or_path, {
|
|
831
853
|
model: options.model_file_name ?? 'model',
|
|
832
854
|
}, options),
|
|
833
|
-
|
|
855
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
856
|
+
generation_config: 'generation_config.json',
|
|
857
|
+
}, options),
|
|
834
858
|
]);
|
|
835
859
|
|
|
836
860
|
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
|
|
@@ -839,7 +863,9 @@ export class PreTrainedModel extends Callable {
|
|
|
839
863
|
model: 'encoder_model',
|
|
840
864
|
decoder_model_merged: 'decoder_model_merged',
|
|
841
865
|
}, options),
|
|
842
|
-
|
|
866
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
867
|
+
generation_config: 'generation_config.json',
|
|
868
|
+
}, options),
|
|
843
869
|
]);
|
|
844
870
|
|
|
845
871
|
} else if (modelType === MODEL_TYPES.MaskGeneration) {
|
|
@@ -869,7 +895,9 @@ export class PreTrainedModel extends Callable {
|
|
|
869
895
|
}
|
|
870
896
|
info = await Promise.all([
|
|
871
897
|
constructSessions(pretrained_model_name_or_path, sessions, options),
|
|
872
|
-
|
|
898
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
899
|
+
generation_config: 'generation_config.json',
|
|
900
|
+
}, options),
|
|
873
901
|
]);
|
|
874
902
|
|
|
875
903
|
} else if (modelType === MODEL_TYPES.Musicgen) {
|
|
@@ -879,12 +907,14 @@ export class PreTrainedModel extends Callable {
|
|
|
879
907
|
decoder_model_merged: 'decoder_model_merged',
|
|
880
908
|
encodec_decode: 'encodec_decode',
|
|
881
909
|
}, options),
|
|
882
|
-
|
|
910
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
911
|
+
generation_config: 'generation_config.json',
|
|
912
|
+
}, options),
|
|
883
913
|
]);
|
|
884
914
|
|
|
885
915
|
} else { // should be MODEL_TYPES.EncoderOnly
|
|
886
916
|
if (modelType !== MODEL_TYPES.EncoderOnly) {
|
|
887
|
-
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at
|
|
917
|
+
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
|
|
888
918
|
}
|
|
889
919
|
info = await Promise.all([
|
|
890
920
|
constructSessions(pretrained_model_name_or_path, {
|
|
@@ -917,6 +947,14 @@ export class PreTrainedModel extends Callable {
|
|
|
917
947
|
return await this._forward(this, model_inputs);
|
|
918
948
|
}
|
|
919
949
|
|
|
950
|
+
/**
|
|
951
|
+
* Get the model's generation config, if it exists.
|
|
952
|
+
* @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
|
|
953
|
+
*/
|
|
954
|
+
get generation_config() {
|
|
955
|
+
return this.configs?.generation_config ?? null;
|
|
956
|
+
}
|
|
957
|
+
|
|
920
958
|
/**
|
|
921
959
|
* This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
|
|
922
960
|
* instances used for multinomial sampling.
|
|
@@ -1096,9 +1134,7 @@ export class PreTrainedModel extends Callable {
|
|
|
1096
1134
|
const gen_config = new cls(config);
|
|
1097
1135
|
|
|
1098
1136
|
// Apply model's generation config, if it exists
|
|
1099
|
-
|
|
1100
|
-
Object.assign(gen_config, this.generation_config);
|
|
1101
|
-
}
|
|
1137
|
+
Object.assign(gen_config, this.generation_config ?? {});
|
|
1102
1138
|
|
|
1103
1139
|
// Next, use any generation config specified by the user
|
|
1104
1140
|
// when calling `generate`
|
|
@@ -1298,35 +1334,37 @@ export class PreTrainedModel extends Callable {
|
|
|
1298
1334
|
let { decoder_input_ids, ...model_inputs } = model_kwargs;
|
|
1299
1335
|
|
|
1300
1336
|
// Prepare input ids if the user has not defined `decoder_input_ids` manually.
|
|
1301
|
-
if (!decoder_input_ids) {
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
if (decoder_start_token_id.length !== batch_size) {
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1337
|
+
if (!(decoder_input_ids instanceof Tensor)) {
|
|
1338
|
+
if (!decoder_input_ids) {
|
|
1339
|
+
decoder_start_token_id ??= bos_token_id;
|
|
1340
|
+
|
|
1341
|
+
if (this.config.model_type === 'musicgen') {
|
|
1342
|
+
// Custom logic (TODO: move to Musicgen class)
|
|
1343
|
+
decoder_input_ids = Array.from({
|
|
1344
|
+
length: batch_size * this.config.decoder.num_codebooks
|
|
1345
|
+
}, () => [decoder_start_token_id]);
|
|
1346
|
+
|
|
1347
|
+
} else if (Array.isArray(decoder_start_token_id)) {
|
|
1348
|
+
if (decoder_start_token_id.length !== batch_size) {
|
|
1349
|
+
throw new Error(
|
|
1350
|
+
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
|
|
1351
|
+
)
|
|
1352
|
+
}
|
|
1353
|
+
decoder_input_ids = decoder_start_token_id;
|
|
1354
|
+
} else {
|
|
1355
|
+
decoder_input_ids = Array.from({
|
|
1356
|
+
length: batch_size,
|
|
1357
|
+
}, () => [decoder_start_token_id]);
|
|
1315
1358
|
}
|
|
1316
|
-
|
|
1317
|
-
|
|
1359
|
+
} else if (!Array.isArray(decoder_input_ids[0])) {
|
|
1360
|
+
// Correct batch size
|
|
1318
1361
|
decoder_input_ids = Array.from({
|
|
1319
1362
|
length: batch_size,
|
|
1320
|
-
}, () => [decoder_start_token_id]);
|
|
1363
|
+
}, () => decoder_input_ids);
|
|
1321
1364
|
}
|
|
1322
|
-
|
|
1323
|
-
// Correct batch size
|
|
1324
|
-
decoder_input_ids = Array.from({
|
|
1325
|
-
length: batch_size,
|
|
1326
|
-
}, () => decoder_input_ids);
|
|
1365
|
+
decoder_input_ids = toI64Tensor(decoder_input_ids);
|
|
1327
1366
|
}
|
|
1328
1367
|
|
|
1329
|
-
decoder_input_ids = toI64Tensor(decoder_input_ids);
|
|
1330
1368
|
model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids);
|
|
1331
1369
|
|
|
1332
1370
|
return { input_ids: decoder_input_ids, model_inputs };
|
|
@@ -1458,13 +1496,12 @@ export class PreTrainedModel extends Callable {
|
|
|
1458
1496
|
// - GenerationMode.BEAM_SEARCH
|
|
1459
1497
|
// - GenerationMode.BEAM_SAMPLE
|
|
1460
1498
|
////////////////////////////////////////////////////
|
|
1461
|
-
let past_key_values = null;
|
|
1499
|
+
let outputs;
|
|
1462
1500
|
let attentions = {};
|
|
1463
1501
|
while (true) {
|
|
1464
1502
|
// prepare model inputs
|
|
1465
1503
|
model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config);
|
|
1466
|
-
|
|
1467
|
-
const outputs = await this.forward(model_inputs);
|
|
1504
|
+
outputs = await this.forward(model_inputs);
|
|
1468
1505
|
|
|
1469
1506
|
if (generation_config.output_attentions && generation_config.return_dict_in_generate) {
|
|
1470
1507
|
// Get attentions if they are present
|
|
@@ -1511,10 +1548,6 @@ export class PreTrainedModel extends Callable {
|
|
|
1511
1548
|
|
|
1512
1549
|
const stop = prepared_stopping_criteria(all_input_ids);
|
|
1513
1550
|
if (stop.every(x => x)) {
|
|
1514
|
-
if (generation_config.return_dict_in_generate) {
|
|
1515
|
-
// Get past key values without disposing buffers
|
|
1516
|
-
past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, false);
|
|
1517
|
-
}
|
|
1518
1551
|
break;
|
|
1519
1552
|
}
|
|
1520
1553
|
|
|
@@ -1527,6 +1560,9 @@ export class PreTrainedModel extends Callable {
|
|
|
1527
1560
|
streamer.end();
|
|
1528
1561
|
}
|
|
1529
1562
|
|
|
1563
|
+
// Retrieve and dispose all final past key values (including encoder attentions)
|
|
1564
|
+
const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true);
|
|
1565
|
+
|
|
1530
1566
|
// TODO: ensure all_input_ids is padded correctly...
|
|
1531
1567
|
const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]);
|
|
1532
1568
|
|
|
@@ -1540,6 +1576,12 @@ export class PreTrainedModel extends Callable {
|
|
|
1540
1576
|
// logits,
|
|
1541
1577
|
}
|
|
1542
1578
|
} else {
|
|
1579
|
+
// Dispose all remaining tensors
|
|
1580
|
+
for (const tensor of Object.values(outputs)) {
|
|
1581
|
+
if (tensor.location === 'gpu-buffer') {
|
|
1582
|
+
tensor.dispose();
|
|
1583
|
+
}
|
|
1584
|
+
}
|
|
1543
1585
|
return sequences;
|
|
1544
1586
|
}
|
|
1545
1587
|
}
|
|
@@ -1549,31 +1591,32 @@ export class PreTrainedModel extends Callable {
|
|
|
1549
1591
|
*
|
|
1550
1592
|
* @param {Object} decoderResults The decoder results object.
|
|
1551
1593
|
* @param {Object} pastKeyValues The previous past key values.
|
|
1552
|
-
* @param {boolean} [dispose=true] Whether to dispose of the old gpu buffer.
|
|
1553
1594
|
* @returns {Object} An object containing past key values.
|
|
1554
1595
|
*/
|
|
1555
|
-
getPastKeyValues(decoderResults, pastKeyValues, dispose = true) {
|
|
1596
|
+
getPastKeyValues(decoderResults, pastKeyValues, disposeEncoderPKVs = false) {
|
|
1556
1597
|
const pkvs = Object.create(null);
|
|
1557
1598
|
|
|
1558
1599
|
for (const name in decoderResults) {
|
|
1559
1600
|
if (name.startsWith('present')) {
|
|
1560
1601
|
const newName = name.replace('present', 'past_key_values');
|
|
1561
|
-
|
|
1562
|
-
if (
|
|
1563
|
-
// Optimization introduced by optimum to reuse past key values.
|
|
1564
|
-
// outputs with the previous past key values.
|
|
1602
|
+
const is_encoder_pkv = name.includes('encoder');
|
|
1603
|
+
if (is_encoder_pkv && pastKeyValues) {
|
|
1604
|
+
// Optimization introduced by optimum to reuse past key values.
|
|
1605
|
+
// So, we just replace the constant outputs (`decoderResults[name]`) with the previous past key values.
|
|
1565
1606
|
// https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
|
|
1566
1607
|
pkvs[newName] = pastKeyValues[newName];
|
|
1567
|
-
} else {
|
|
1568
|
-
if (dispose && pastKeyValues) {
|
|
1569
|
-
// Free old gpu buffer
|
|
1570
|
-
const t = pastKeyValues[newName];
|
|
1571
|
-
if (t.location === 'gpu-buffer') {
|
|
1572
|
-
t.dispose();
|
|
1573
|
-
}
|
|
1574
|
-
}
|
|
1608
|
+
} else { // decoder or using first encoder PKVs
|
|
1575
1609
|
pkvs[newName] = decoderResults[name];
|
|
1576
1610
|
}
|
|
1611
|
+
|
|
1612
|
+
if (pastKeyValues && (!is_encoder_pkv || disposeEncoderPKVs)) {
|
|
1613
|
+
// - Always dispose decoder PKVs
|
|
1614
|
+
// - Only dispose encoder past key values when requested (after generation)
|
|
1615
|
+
const t = pastKeyValues[newName];
|
|
1616
|
+
if (t.location === 'gpu-buffer') {
|
|
1617
|
+
t.dispose();
|
|
1618
|
+
}
|
|
1619
|
+
}
|
|
1577
1620
|
}
|
|
1578
1621
|
}
|
|
1579
1622
|
return pkvs;
|
|
@@ -1611,9 +1654,8 @@ export class PreTrainedModel extends Callable {
|
|
|
1611
1654
|
if (pastKeyValues) {
|
|
1612
1655
|
Object.assign(decoderFeeds, pastKeyValues)
|
|
1613
1656
|
} else {
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
const dtype = this.custom_config.kv_cache_dtype ?? 'float32';
|
|
1657
|
+
const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
|
|
1658
|
+
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
|
|
1617
1659
|
const empty = (dtype === 'float16') ? new Uint16Array() : [];
|
|
1618
1660
|
|
|
1619
1661
|
const shapes = getKeyValueShapes(this.config);
|
|
@@ -2506,17 +2548,6 @@ export class T5PreTrainedModel extends PreTrainedModel {
|
|
|
2506
2548
|
'decoder_attention_mask',
|
|
2507
2549
|
'past_key_values',
|
|
2508
2550
|
];
|
|
2509
|
-
|
|
2510
|
-
/**
|
|
2511
|
-
* Creates a new instance of the `T5PreTrainedModel` class.
|
|
2512
|
-
* @param {Object} config The model configuration.
|
|
2513
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2514
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2515
|
-
*/
|
|
2516
|
-
constructor(config, sessions, generation_config) {
|
|
2517
|
-
super(config, sessions);
|
|
2518
|
-
this.generation_config = generation_config;
|
|
2519
|
-
}
|
|
2520
2551
|
};
|
|
2521
2552
|
|
|
2522
2553
|
export class T5Model extends T5PreTrainedModel { }
|
|
@@ -2533,18 +2564,7 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { }
|
|
|
2533
2564
|
/**
|
|
2534
2565
|
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
2535
2566
|
*/
|
|
2536
|
-
export class LongT5PreTrainedModel extends PreTrainedModel {
|
|
2537
|
-
/**
|
|
2538
|
-
* Creates a new instance of the `LongT5ForConditionalGeneration` class.
|
|
2539
|
-
* @param {Object} config The model configuration.
|
|
2540
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2541
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2542
|
-
*/
|
|
2543
|
-
constructor(config, sessions, generation_config) {
|
|
2544
|
-
super(config, sessions);
|
|
2545
|
-
this.generation_config = generation_config;
|
|
2546
|
-
}
|
|
2547
|
-
};
|
|
2567
|
+
export class LongT5PreTrainedModel extends PreTrainedModel { };
|
|
2548
2568
|
|
|
2549
2569
|
/**
|
|
2550
2570
|
* The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -2560,19 +2580,7 @@ export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { }
|
|
|
2560
2580
|
|
|
2561
2581
|
//////////////////////////////////////////////////
|
|
2562
2582
|
// MT5 models
|
|
2563
|
-
export class MT5PreTrainedModel extends PreTrainedModel {
|
|
2564
|
-
|
|
2565
|
-
/**
|
|
2566
|
-
* Creates a new instance of the `MT5ForConditionalGeneration` class.
|
|
2567
|
-
* @param {Object} config The model configuration.
|
|
2568
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2569
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2570
|
-
*/
|
|
2571
|
-
constructor(config, sessions, generation_config) {
|
|
2572
|
-
super(config, sessions);
|
|
2573
|
-
this.generation_config = generation_config;
|
|
2574
|
-
}
|
|
2575
|
-
};
|
|
2583
|
+
export class MT5PreTrainedModel extends PreTrainedModel { };
|
|
2576
2584
|
|
|
2577
2585
|
export class MT5Model extends MT5PreTrainedModel { }
|
|
2578
2586
|
|
|
@@ -2584,19 +2592,7 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { }
|
|
|
2584
2592
|
|
|
2585
2593
|
//////////////////////////////////////////////////
|
|
2586
2594
|
// Bart models
|
|
2587
|
-
export class BartPretrainedModel extends PreTrainedModel {
|
|
2588
|
-
|
|
2589
|
-
/**
|
|
2590
|
-
* Creates a new instance of the `BartForConditionalGeneration` class.
|
|
2591
|
-
* @param {Object} config The model configuration.
|
|
2592
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2593
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2594
|
-
*/
|
|
2595
|
-
constructor(config, sessions, generation_config) {
|
|
2596
|
-
super(config, sessions);
|
|
2597
|
-
this.generation_config = generation_config;
|
|
2598
|
-
}
|
|
2599
|
-
};
|
|
2595
|
+
export class BartPretrainedModel extends PreTrainedModel { };
|
|
2600
2596
|
|
|
2601
2597
|
/**
|
|
2602
2598
|
* The bare BART Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2627,19 +2623,7 @@ export class BartForSequenceClassification extends BartPretrainedModel {
|
|
|
2627
2623
|
|
|
2628
2624
|
//////////////////////////////////////////////////
|
|
2629
2625
|
// MBart models
|
|
2630
|
-
export class MBartPreTrainedModel extends PreTrainedModel {
|
|
2631
|
-
|
|
2632
|
-
/**
|
|
2633
|
-
* Creates a new instance of the `MBartForConditionalGeneration` class.
|
|
2634
|
-
* @param {Object} config The model configuration.
|
|
2635
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2636
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2637
|
-
*/
|
|
2638
|
-
constructor(config, sessions, generation_config) {
|
|
2639
|
-
super(config, sessions);
|
|
2640
|
-
this.generation_config = generation_config;
|
|
2641
|
-
}
|
|
2642
|
-
};
|
|
2626
|
+
export class MBartPreTrainedModel extends PreTrainedModel { };
|
|
2643
2627
|
|
|
2644
2628
|
/**
|
|
2645
2629
|
* The bare MBART Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2673,19 +2657,7 @@ export class MBartForCausalLM extends MBartPreTrainedModel { }
|
|
|
2673
2657
|
|
|
2674
2658
|
//////////////////////////////////////////////////
|
|
2675
2659
|
// Blenderbot models
|
|
2676
|
-
export class BlenderbotPreTrainedModel extends PreTrainedModel {
|
|
2677
|
-
|
|
2678
|
-
/**
|
|
2679
|
-
* Creates a new instance of the `BlenderbotForConditionalGeneration` class.
|
|
2680
|
-
* @param {Object} config The model configuration.
|
|
2681
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2682
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2683
|
-
*/
|
|
2684
|
-
constructor(config, sessions, generation_config) {
|
|
2685
|
-
super(config, sessions);
|
|
2686
|
-
this.generation_config = generation_config;
|
|
2687
|
-
}
|
|
2688
|
-
};
|
|
2660
|
+
export class BlenderbotPreTrainedModel extends PreTrainedModel { };
|
|
2689
2661
|
|
|
2690
2662
|
/**
|
|
2691
2663
|
* The bare Blenderbot Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2701,19 +2673,7 @@ export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedMode
|
|
|
2701
2673
|
|
|
2702
2674
|
//////////////////////////////////////////////////
|
|
2703
2675
|
// Blenderbot models
|
|
2704
|
-
export class BlenderbotSmallPreTrainedModel extends PreTrainedModel {
|
|
2705
|
-
|
|
2706
|
-
/**
|
|
2707
|
-
* Creates a new instance of the `BlenderbotForConditionalGeneration` class.
|
|
2708
|
-
* @param {Object} config The model configuration.
|
|
2709
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2710
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2711
|
-
*/
|
|
2712
|
-
constructor(config, sessions, generation_config) {
|
|
2713
|
-
super(config, sessions);
|
|
2714
|
-
this.generation_config = generation_config;
|
|
2715
|
-
}
|
|
2716
|
-
};
|
|
2676
|
+
export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { };
|
|
2717
2677
|
|
|
2718
2678
|
/**
|
|
2719
2679
|
* The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2962,17 +2922,6 @@ export class WhisperPreTrainedModel extends PreTrainedModel {
|
|
|
2962
2922
|
'decoder_attention_mask',
|
|
2963
2923
|
'past_key_values',
|
|
2964
2924
|
];
|
|
2965
|
-
|
|
2966
|
-
/**
|
|
2967
|
-
* Creates a new instance of the `WhisperPreTrainedModel` class.
|
|
2968
|
-
* @param {Object} config The model configuration.
|
|
2969
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2970
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2971
|
-
*/
|
|
2972
|
-
constructor(config, sessions, generation_config) {
|
|
2973
|
-
super(config, sessions);
|
|
2974
|
-
this.generation_config = generation_config;
|
|
2975
|
-
}
|
|
2976
2925
|
};
|
|
2977
2926
|
|
|
2978
2927
|
/**
|
|
@@ -3238,21 +3187,14 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
|
|
|
3238
3187
|
export class VisionEncoderDecoderModel extends PreTrainedModel {
|
|
3239
3188
|
main_input_name = 'pixel_values';
|
|
3240
3189
|
forward_params = [
|
|
3190
|
+
// Encoder inputs
|
|
3241
3191
|
'pixel_values',
|
|
3242
|
-
|
|
3192
|
+
|
|
3193
|
+
// Decoder inputs
|
|
3194
|
+
'decoder_input_ids',
|
|
3243
3195
|
'encoder_hidden_states',
|
|
3244
3196
|
'past_key_values',
|
|
3245
3197
|
];
|
|
3246
|
-
/**
|
|
3247
|
-
* Creates a new instance of the `VisionEncoderDecoderModel` class.
|
|
3248
|
-
* @param {Object} config The model configuration.
|
|
3249
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3250
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3251
|
-
*/
|
|
3252
|
-
constructor(config, sessions, generation_config) {
|
|
3253
|
-
super(config, sessions);
|
|
3254
|
-
this.generation_config = generation_config;
|
|
3255
|
-
}
|
|
3256
3198
|
}
|
|
3257
3199
|
//////////////////////////////////////////////////
|
|
3258
3200
|
|
|
@@ -3267,11 +3209,6 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
|
|
|
3267
3209
|
'position_ids',
|
|
3268
3210
|
'past_key_values',
|
|
3269
3211
|
];
|
|
3270
|
-
|
|
3271
|
-
constructor(config, sessions, generation_config) {
|
|
3272
|
-
super(config, sessions);
|
|
3273
|
-
this.generation_config = generation_config;
|
|
3274
|
-
}
|
|
3275
3212
|
}
|
|
3276
3213
|
|
|
3277
3214
|
/**
|
|
@@ -3358,11 +3295,6 @@ export class Florence2PreTrainedModel extends PreTrainedModel {
|
|
|
3358
3295
|
'past_key_values',
|
|
3359
3296
|
];
|
|
3360
3297
|
main_input_name = 'inputs_embeds';
|
|
3361
|
-
|
|
3362
|
-
constructor(config, sessions, generation_config) {
|
|
3363
|
-
super(config, sessions);
|
|
3364
|
-
this.generation_config = generation_config;
|
|
3365
|
-
}
|
|
3366
3298
|
}
|
|
3367
3299
|
|
|
3368
3300
|
export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel {
|
|
@@ -3501,6 +3433,18 @@ export class CLIPPreTrainedModel extends PreTrainedModel { }
|
|
|
3501
3433
|
*/
|
|
3502
3434
|
export class CLIPModel extends CLIPPreTrainedModel { }
|
|
3503
3435
|
|
|
3436
|
+
/**
|
|
3437
|
+
* The text model from CLIP without any head or projection on top.
|
|
3438
|
+
*/
|
|
3439
|
+
export class CLIPTextModel extends CLIPPreTrainedModel {
|
|
3440
|
+
/** @type {PreTrainedModel.from_pretrained} */
|
|
3441
|
+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3442
|
+
// Update default model file name if not provided
|
|
3443
|
+
options.model_file_name ??= 'text_model';
|
|
3444
|
+
return super.from_pretrained(pretrained_model_name_or_path, options);
|
|
3445
|
+
}
|
|
3446
|
+
}
|
|
3447
|
+
|
|
3504
3448
|
/**
|
|
3505
3449
|
* CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output)
|
|
3506
3450
|
*
|
|
@@ -3528,7 +3472,6 @@ export class CLIPModel extends CLIPPreTrainedModel { }
|
|
|
3528
3472
|
* ```
|
|
3529
3473
|
*/
|
|
3530
3474
|
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
|
|
3531
|
-
|
|
3532
3475
|
/** @type {PreTrainedModel.from_pretrained} */
|
|
3533
3476
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3534
3477
|
// Update default model file name if not provided
|
|
@@ -3537,6 +3480,18 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
|
|
|
3537
3480
|
}
|
|
3538
3481
|
}
|
|
3539
3482
|
|
|
3483
|
+
/**
|
|
3484
|
+
* The vision model from CLIP without any head or projection on top.
|
|
3485
|
+
*/
|
|
3486
|
+
export class CLIPVisionModel extends CLIPPreTrainedModel {
|
|
3487
|
+
/** @type {PreTrainedModel.from_pretrained} */
|
|
3488
|
+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3489
|
+
// Update default model file name if not provided
|
|
3490
|
+
options.model_file_name ??= 'vision_model';
|
|
3491
|
+
return super.from_pretrained(pretrained_model_name_or_path, options);
|
|
3492
|
+
}
|
|
3493
|
+
}
|
|
3494
|
+
|
|
3540
3495
|
/**
|
|
3541
3496
|
* CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output)
|
|
3542
3497
|
*
|
|
@@ -3759,18 +3714,7 @@ export class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel { }
|
|
|
3759
3714
|
|
|
3760
3715
|
//////////////////////////////////////////////////
|
|
3761
3716
|
// GPT2 models
|
|
3762
|
-
export class GPT2PreTrainedModel extends PreTrainedModel {
|
|
3763
|
-
/**
|
|
3764
|
-
* Creates a new instance of the `GPT2PreTrainedModel` class.
|
|
3765
|
-
* @param {Object} config The model configuration.
|
|
3766
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3767
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3768
|
-
*/
|
|
3769
|
-
constructor(config, sessions, generation_config) {
|
|
3770
|
-
super(config, sessions);
|
|
3771
|
-
this.generation_config = generation_config;
|
|
3772
|
-
}
|
|
3773
|
-
}
|
|
3717
|
+
export class GPT2PreTrainedModel extends PreTrainedModel { }
|
|
3774
3718
|
|
|
3775
3719
|
export class GPT2Model extends GPT2PreTrainedModel { }
|
|
3776
3720
|
|
|
@@ -3783,20 +3727,25 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { }
|
|
|
3783
3727
|
// }
|
|
3784
3728
|
//////////////////////////////////////////////////
|
|
3785
3729
|
|
|
3730
|
+
//////////////////////////////////////////////////
|
|
3731
|
+
// JAIS models
|
|
3732
|
+
export class JAISPreTrainedModel extends PreTrainedModel { }
|
|
3733
|
+
|
|
3734
|
+
/**
|
|
3735
|
+
* The bare JAIS Model transformer outputting raw hidden-states without any specific head on top.
|
|
3736
|
+
*/
|
|
3737
|
+
export class JAISModel extends JAISPreTrainedModel { }
|
|
3738
|
+
|
|
3739
|
+
/**
|
|
3740
|
+
* The JAIS Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
|
3741
|
+
*/
|
|
3742
|
+
export class JAISLMHeadModel extends JAISPreTrainedModel { }
|
|
3743
|
+
//////////////////////////////////////////////////
|
|
3744
|
+
|
|
3745
|
+
|
|
3786
3746
|
//////////////////////////////////////////////////
|
|
3787
3747
|
// GPTNeo models
|
|
3788
|
-
export class GPTNeoPreTrainedModel extends PreTrainedModel {
|
|
3789
|
-
/**
|
|
3790
|
-
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
|
|
3791
|
-
* @param {Object} config The model configuration.
|
|
3792
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3793
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3794
|
-
*/
|
|
3795
|
-
constructor(config, sessions, generation_config) {
|
|
3796
|
-
super(config, sessions);
|
|
3797
|
-
this.generation_config = generation_config;
|
|
3798
|
-
}
|
|
3799
|
-
}
|
|
3748
|
+
export class GPTNeoPreTrainedModel extends PreTrainedModel { }
|
|
3800
3749
|
export class GPTNeoModel extends GPTNeoPreTrainedModel { }
|
|
3801
3750
|
|
|
3802
3751
|
export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
|
|
@@ -3804,18 +3753,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
|
|
|
3804
3753
|
|
|
3805
3754
|
//////////////////////////////////////////////////
|
|
3806
3755
|
// GPTNeoX models
|
|
3807
|
-
export class GPTNeoXPreTrainedModel extends PreTrainedModel {
|
|
3808
|
-
/**
|
|
3809
|
-
* Creates a new instance of the `GPTNeoXPreTrainedModel` class.
|
|
3810
|
-
* @param {Object} config The model configuration.
|
|
3811
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3812
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3813
|
-
*/
|
|
3814
|
-
constructor(config, sessions, generation_config) {
|
|
3815
|
-
super(config, sessions);
|
|
3816
|
-
this.generation_config = generation_config;
|
|
3817
|
-
}
|
|
3818
|
-
}
|
|
3756
|
+
export class GPTNeoXPreTrainedModel extends PreTrainedModel { }
|
|
3819
3757
|
export class GPTNeoXModel extends GPTNeoXPreTrainedModel { }
|
|
3820
3758
|
|
|
3821
3759
|
export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
|
|
@@ -3824,18 +3762,7 @@ export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
|
|
|
3824
3762
|
|
|
3825
3763
|
//////////////////////////////////////////////////
|
|
3826
3764
|
// GPT-J models
|
|
3827
|
-
export class GPTJPreTrainedModel extends PreTrainedModel {
|
|
3828
|
-
/**
|
|
3829
|
-
* Creates a new instance of the `GPTJPreTrainedModel` class.
|
|
3830
|
-
* @param {Object} config The model configuration.
|
|
3831
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3832
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3833
|
-
*/
|
|
3834
|
-
constructor(config, sessions, generation_config) {
|
|
3835
|
-
super(config, sessions);
|
|
3836
|
-
this.generation_config = generation_config;
|
|
3837
|
-
}
|
|
3838
|
-
}
|
|
3765
|
+
export class GPTJPreTrainedModel extends PreTrainedModel { }
|
|
3839
3766
|
|
|
3840
3767
|
export class GPTJModel extends GPTJPreTrainedModel { }
|
|
3841
3768
|
|
|
@@ -3845,18 +3772,7 @@ export class GPTJForCausalLM extends GPTJPreTrainedModel { }
|
|
|
3845
3772
|
|
|
3846
3773
|
//////////////////////////////////////////////////
|
|
3847
3774
|
// GPTBigCode models
|
|
3848
|
-
export class GPTBigCodePreTrainedModel extends PreTrainedModel {
|
|
3849
|
-
/**
|
|
3850
|
-
* Creates a new instance of the `GPTBigCodePreTrainedModel` class.
|
|
3851
|
-
* @param {Object} config The model configuration.
|
|
3852
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3853
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3854
|
-
*/
|
|
3855
|
-
constructor(config, sessions, generation_config) {
|
|
3856
|
-
super(config, sessions);
|
|
3857
|
-
this.generation_config = generation_config;
|
|
3858
|
-
}
|
|
3859
|
-
}
|
|
3775
|
+
export class GPTBigCodePreTrainedModel extends PreTrainedModel { }
|
|
3860
3776
|
|
|
3861
3777
|
export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { }
|
|
3862
3778
|
|
|
@@ -3865,18 +3781,7 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
|
|
|
3865
3781
|
|
|
3866
3782
|
//////////////////////////////////////////////////
|
|
3867
3783
|
// CodeGen models
|
|
3868
|
-
export class CodeGenPreTrainedModel extends PreTrainedModel {
|
|
3869
|
-
/**
|
|
3870
|
-
* Creates a new instance of the `CodeGenPreTrainedModel` class.
|
|
3871
|
-
* @param {Object} config The model configuration.
|
|
3872
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3873
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3874
|
-
*/
|
|
3875
|
-
constructor(config, sessions, generation_config) {
|
|
3876
|
-
super(config, sessions);
|
|
3877
|
-
this.generation_config = generation_config;
|
|
3878
|
-
}
|
|
3879
|
-
}
|
|
3784
|
+
export class CodeGenPreTrainedModel extends PreTrainedModel { }
|
|
3880
3785
|
/**
|
|
3881
3786
|
* CodeGenModel is a class representing a code generation model without a language model head.
|
|
3882
3787
|
*/
|
|
@@ -3895,18 +3800,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
|
|
|
3895
3800
|
/**
|
|
3896
3801
|
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
|
|
3897
3802
|
*/
|
|
3898
|
-
export class LlamaPreTrainedModel extends PreTrainedModel {
|
|
3899
|
-
/**
|
|
3900
|
-
* Creates a new instance of the `LlamaPreTrainedModel` class.
|
|
3901
|
-
* @param {Object} config The model configuration.
|
|
3902
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3903
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3904
|
-
*/
|
|
3905
|
-
constructor(config, sessions, generation_config) {
|
|
3906
|
-
super(config, sessions);
|
|
3907
|
-
this.generation_config = generation_config;
|
|
3908
|
-
}
|
|
3909
|
-
}
|
|
3803
|
+
export class LlamaPreTrainedModel extends PreTrainedModel { }
|
|
3910
3804
|
/**
|
|
3911
3805
|
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
|
|
3912
3806
|
*/
|
|
@@ -3915,24 +3809,22 @@ export class LlamaModel extends LlamaPreTrainedModel { }
|
|
|
3915
3809
|
export class LlamaForCausalLM extends LlamaPreTrainedModel { }
|
|
3916
3810
|
//////////////////////////////////////////////////
|
|
3917
3811
|
|
|
3812
|
+
|
|
3813
|
+
//////////////////////////////////////////////////
|
|
3814
|
+
// Granite models
|
|
3815
|
+
export class GranitePreTrainedModel extends PreTrainedModel { }
|
|
3816
|
+
export class GraniteModel extends GranitePreTrainedModel { }
|
|
3817
|
+
export class GraniteForCausalLM extends GranitePreTrainedModel { }
|
|
3818
|
+
//////////////////////////////////////////////////
|
|
3819
|
+
|
|
3820
|
+
|
|
3918
3821
|
//////////////////////////////////////////////////
|
|
3919
3822
|
// Cohere models
|
|
3920
3823
|
|
|
3921
3824
|
/**
|
|
3922
3825
|
* The bare Cohere Model outputting raw hidden-states without any specific head on top.
|
|
3923
3826
|
*/
|
|
3924
|
-
export class CoherePreTrainedModel extends PreTrainedModel {
|
|
3925
|
-
/**
|
|
3926
|
-
* Creates a new instance of the `CoherePreTrainedModel` class.
|
|
3927
|
-
* @param {Object} config The model configuration.
|
|
3928
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3929
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3930
|
-
*/
|
|
3931
|
-
constructor(config, sessions, generation_config) {
|
|
3932
|
-
super(config, sessions);
|
|
3933
|
-
this.generation_config = generation_config;
|
|
3934
|
-
}
|
|
3935
|
-
}
|
|
3827
|
+
export class CoherePreTrainedModel extends PreTrainedModel { }
|
|
3936
3828
|
export class CohereModel extends CoherePreTrainedModel { }
|
|
3937
3829
|
|
|
3938
3830
|
export class CohereForCausalLM extends CoherePreTrainedModel { }
|
|
@@ -3944,18 +3836,7 @@ export class CohereForCausalLM extends CoherePreTrainedModel { }
|
|
|
3944
3836
|
/**
|
|
3945
3837
|
* The bare Gemma Model outputting raw hidden-states without any specific head on top.
|
|
3946
3838
|
*/
|
|
3947
|
-
export class GemmaPreTrainedModel extends PreTrainedModel {
|
|
3948
|
-
/**
|
|
3949
|
-
* Creates a new instance of the `GemmaPreTrainedModel` class.
|
|
3950
|
-
* @param {Object} config The model configuration.
|
|
3951
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3952
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3953
|
-
*/
|
|
3954
|
-
constructor(config, sessions, generation_config) {
|
|
3955
|
-
super(config, sessions);
|
|
3956
|
-
this.generation_config = generation_config;
|
|
3957
|
-
}
|
|
3958
|
-
}
|
|
3839
|
+
export class GemmaPreTrainedModel extends PreTrainedModel { }
|
|
3959
3840
|
/**
|
|
3960
3841
|
* The bare Gemma Model outputting raw hidden-states without any specific head on top.
|
|
3961
3842
|
*/
|
|
@@ -3970,18 +3851,7 @@ export class GemmaForCausalLM extends GemmaPreTrainedModel { }
|
|
|
3970
3851
|
/**
|
|
3971
3852
|
* The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
|
|
3972
3853
|
*/
|
|
3973
|
-
export class Gemma2PreTrainedModel extends PreTrainedModel {
|
|
3974
|
-
/**
|
|
3975
|
-
* Creates a new instance of the `Gemma2PreTrainedModel` class.
|
|
3976
|
-
* @param {Object} config The model configuration.
|
|
3977
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3978
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3979
|
-
*/
|
|
3980
|
-
constructor(config, sessions, generation_config) {
|
|
3981
|
-
super(config, sessions);
|
|
3982
|
-
this.generation_config = generation_config;
|
|
3983
|
-
}
|
|
3984
|
-
}
|
|
3854
|
+
export class Gemma2PreTrainedModel extends PreTrainedModel { }
|
|
3985
3855
|
/**
|
|
3986
3856
|
* The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
|
|
3987
3857
|
*/
|
|
@@ -3991,18 +3861,7 @@ export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
|
|
|
3991
3861
|
//////////////////////////////////////////////////
|
|
3992
3862
|
|
|
3993
3863
|
//////////////////////////////////////////////////
|
|
3994
|
-
export class OpenELMPreTrainedModel extends PreTrainedModel {
|
|
3995
|
-
/**
|
|
3996
|
-
* Creates a new instance of the `OpenELMPreTrainedModel` class.
|
|
3997
|
-
* @param {Object} config The model configuration.
|
|
3998
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3999
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4000
|
-
*/
|
|
4001
|
-
constructor(config, sessions, generation_config) {
|
|
4002
|
-
super(config, sessions);
|
|
4003
|
-
this.generation_config = generation_config;
|
|
4004
|
-
}
|
|
4005
|
-
}
|
|
3864
|
+
export class OpenELMPreTrainedModel extends PreTrainedModel { }
|
|
4006
3865
|
export class OpenELMModel extends OpenELMPreTrainedModel { }
|
|
4007
3866
|
|
|
4008
3867
|
export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
|
|
@@ -4014,18 +3873,7 @@ export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
|
|
|
4014
3873
|
/**
|
|
4015
3874
|
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
|
|
4016
3875
|
*/
|
|
4017
|
-
export class Qwen2PreTrainedModel extends PreTrainedModel {
|
|
4018
|
-
/**
|
|
4019
|
-
* Creates a new instance of the `Qwen2PreTrainedModel` class.
|
|
4020
|
-
* @param {Object} config The model configuration.
|
|
4021
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4022
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4023
|
-
*/
|
|
4024
|
-
constructor(config, sessions, generation_config) {
|
|
4025
|
-
super(config, sessions);
|
|
4026
|
-
this.generation_config = generation_config;
|
|
4027
|
-
}
|
|
4028
|
-
}
|
|
3876
|
+
export class Qwen2PreTrainedModel extends PreTrainedModel { }
|
|
4029
3877
|
/**
|
|
4030
3878
|
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
|
|
4031
3879
|
*/
|
|
@@ -4037,18 +3885,7 @@ export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
|
|
|
4037
3885
|
|
|
4038
3886
|
//////////////////////////////////////////////////
|
|
4039
3887
|
// Phi models
|
|
4040
|
-
export class PhiPreTrainedModel extends PreTrainedModel {
|
|
4041
|
-
/**
|
|
4042
|
-
* Creates a new instance of the `PhiPreTrainedModel` class.
|
|
4043
|
-
* @param {Object} config The model configuration.
|
|
4044
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4045
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4046
|
-
*/
|
|
4047
|
-
constructor(config, sessions, generation_config) {
|
|
4048
|
-
super(config, sessions);
|
|
4049
|
-
this.generation_config = generation_config;
|
|
4050
|
-
}
|
|
4051
|
-
}
|
|
3888
|
+
export class PhiPreTrainedModel extends PreTrainedModel { }
|
|
4052
3889
|
/**
|
|
4053
3890
|
* The bare Phi Model outputting raw hidden-states without any specific head on top.
|
|
4054
3891
|
*/
|
|
@@ -4059,18 +3896,7 @@ export class PhiForCausalLM extends PhiPreTrainedModel { }
|
|
|
4059
3896
|
|
|
4060
3897
|
//////////////////////////////////////////////////
|
|
4061
3898
|
// Phi3 models
|
|
4062
|
-
export class Phi3PreTrainedModel extends PreTrainedModel {
|
|
4063
|
-
/**
|
|
4064
|
-
* Creates a new instance of the `Phi3PreTrainedModel` class.
|
|
4065
|
-
* @param {Object} config The model configuration.
|
|
4066
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4067
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4068
|
-
*/
|
|
4069
|
-
constructor(config, sessions, generation_config) {
|
|
4070
|
-
super(config, sessions);
|
|
4071
|
-
this.generation_config = generation_config;
|
|
4072
|
-
}
|
|
4073
|
-
}
|
|
3899
|
+
export class Phi3PreTrainedModel extends PreTrainedModel { }
|
|
4074
3900
|
|
|
4075
3901
|
/**
|
|
4076
3902
|
* The bare Phi3 Model outputting raw hidden-states without any specific head on top.
|
|
@@ -4086,18 +3912,7 @@ export class Phi3ForCausalLM extends Phi3PreTrainedModel { }
|
|
|
4086
3912
|
/**
|
|
4087
3913
|
* The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
|
4088
3914
|
*/
|
|
4089
|
-
export class BloomPreTrainedModel extends PreTrainedModel {
|
|
4090
|
-
/**
|
|
4091
|
-
* Creates a new instance of the `BloomPreTrainedModel` class.
|
|
4092
|
-
* @param {Object} config The model configuration.
|
|
4093
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4094
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4095
|
-
*/
|
|
4096
|
-
constructor(config, sessions, generation_config) {
|
|
4097
|
-
super(config, sessions);
|
|
4098
|
-
this.generation_config = generation_config;
|
|
4099
|
-
}
|
|
4100
|
-
}
|
|
3915
|
+
export class BloomPreTrainedModel extends PreTrainedModel { }
|
|
4101
3916
|
|
|
4102
3917
|
/**
|
|
4103
3918
|
* The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -4112,18 +3927,7 @@ export class BloomForCausalLM extends BloomPreTrainedModel { }
|
|
|
4112
3927
|
|
|
4113
3928
|
//////////////////////////////////////////////////
|
|
4114
3929
|
// MPT models
|
|
4115
|
-
export class MptPreTrainedModel extends PreTrainedModel {
|
|
4116
|
-
/**
|
|
4117
|
-
* Creates a new instance of the `MptPreTrainedModel` class.
|
|
4118
|
-
* @param {Object} config The model configuration.
|
|
4119
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4120
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4121
|
-
*/
|
|
4122
|
-
constructor(config, sessions, generation_config) {
|
|
4123
|
-
super(config, sessions);
|
|
4124
|
-
this.generation_config = generation_config;
|
|
4125
|
-
}
|
|
4126
|
-
}
|
|
3930
|
+
export class MptPreTrainedModel extends PreTrainedModel { }
|
|
4127
3931
|
|
|
4128
3932
|
/**
|
|
4129
3933
|
* The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -4139,18 +3943,7 @@ export class MptForCausalLM extends MptPreTrainedModel { }
|
|
|
4139
3943
|
|
|
4140
3944
|
//////////////////////////////////////////////////
|
|
4141
3945
|
// OPT models
|
|
4142
|
-
export class OPTPreTrainedModel extends PreTrainedModel {
|
|
4143
|
-
/**
|
|
4144
|
-
* Creates a new instance of the `OPTPreTrainedModel` class.
|
|
4145
|
-
* @param {Object} config The model configuration.
|
|
4146
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4147
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4148
|
-
*/
|
|
4149
|
-
constructor(config, sessions, generation_config) {
|
|
4150
|
-
super(config, sessions);
|
|
4151
|
-
this.generation_config = generation_config;
|
|
4152
|
-
}
|
|
4153
|
-
}
|
|
3946
|
+
export class OPTPreTrainedModel extends PreTrainedModel { }
|
|
4154
3947
|
|
|
4155
3948
|
/**
|
|
4156
3949
|
* The bare OPT Model outputting raw hidden-states without any specific head on top.
|
|
@@ -4176,6 +3969,43 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
|
|
|
4176
3969
|
}
|
|
4177
3970
|
//////////////////////////////////////////////////
|
|
4178
3971
|
|
|
3972
|
+
//////////////////////////////////////////////////
|
|
3973
|
+
export class PvtPreTrainedModel extends PreTrainedModel { }
|
|
3974
|
+
export class PvtModel extends PvtPreTrainedModel { }
|
|
3975
|
+
export class PvtForImageClassification extends PvtPreTrainedModel {
|
|
3976
|
+
/**
|
|
3977
|
+
* @param {any} model_inputs
|
|
3978
|
+
*/
|
|
3979
|
+
async _call(model_inputs) {
|
|
3980
|
+
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
3981
|
+
}
|
|
3982
|
+
}
|
|
3983
|
+
//////////////////////////////////////////////////
|
|
3984
|
+
|
|
3985
|
+
//////////////////////////////////////////////////
|
|
3986
|
+
export class ViTMAEPreTrainedModel extends PreTrainedModel { }
|
|
3987
|
+
export class ViTMAEModel extends ViTMAEPreTrainedModel { }
|
|
3988
|
+
//////////////////////////////////////////////////
|
|
3989
|
+
|
|
3990
|
+
|
|
3991
|
+
//////////////////////////////////////////////////
|
|
3992
|
+
export class ViTMSNPreTrainedModel extends PreTrainedModel { }
|
|
3993
|
+
export class ViTMSNModel extends ViTMSNPreTrainedModel { }
|
|
3994
|
+
export class ViTMSNForImageClassification extends ViTMSNPreTrainedModel {
|
|
3995
|
+
/**
|
|
3996
|
+
* @param {any} model_inputs
|
|
3997
|
+
*/
|
|
3998
|
+
async _call(model_inputs) {
|
|
3999
|
+
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
4000
|
+
}
|
|
4001
|
+
}
|
|
4002
|
+
//////////////////////////////////////////////////
|
|
4003
|
+
|
|
4004
|
+
//////////////////////////////////////////////////
|
|
4005
|
+
export class GroupViTPreTrainedModel extends PreTrainedModel { }
|
|
4006
|
+
export class GroupViTModel extends GroupViTPreTrainedModel { }
|
|
4007
|
+
//////////////////////////////////////////////////
|
|
4008
|
+
|
|
4179
4009
|
|
|
4180
4010
|
//////////////////////////////////////////////////
|
|
4181
4011
|
export class FastViTPreTrainedModel extends PreTrainedModel { }
|
|
@@ -4429,6 +4259,19 @@ export class DeiTForImageClassification extends DeiTPreTrainedModel {
|
|
|
4429
4259
|
}
|
|
4430
4260
|
//////////////////////////////////////////////////
|
|
4431
4261
|
|
|
4262
|
+
//////////////////////////////////////////////////
|
|
4263
|
+
export class HieraPreTrainedModel extends PreTrainedModel { }
|
|
4264
|
+
export class HieraModel extends HieraPreTrainedModel { }
|
|
4265
|
+
export class HieraForImageClassification extends HieraPreTrainedModel {
|
|
4266
|
+
/**
|
|
4267
|
+
* @param {any} model_inputs
|
|
4268
|
+
*/
|
|
4269
|
+
async _call(model_inputs) {
|
|
4270
|
+
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
4271
|
+
}
|
|
4272
|
+
}
|
|
4273
|
+
//////////////////////////////////////////////////
|
|
4274
|
+
|
|
4432
4275
|
|
|
4433
4276
|
//////////////////////////////////////////////////
|
|
4434
4277
|
/**
|
|
@@ -4568,6 +4411,24 @@ export class DepthAnythingForDepthEstimation extends DepthAnythingPreTrainedMode
|
|
|
4568
4411
|
//////////////////////////////////////////////////
|
|
4569
4412
|
|
|
4570
4413
|
|
|
4414
|
+
//////////////////////////////////////////////////
|
|
4415
|
+
export class SapiensPreTrainedModel extends PreTrainedModel { }
|
|
4416
|
+
export class SapiensForSemanticSegmentation extends SapiensPreTrainedModel { }
|
|
4417
|
+
export class SapiensForDepthEstimation extends SapiensPreTrainedModel { }
|
|
4418
|
+
export class SapiensForNormalEstimation extends SapiensPreTrainedModel { }
|
|
4419
|
+
//////////////////////////////////////////////////
|
|
4420
|
+
|
|
4421
|
+
//////////////////////////////////////////////////
|
|
4422
|
+
export class DepthProPreTrainedModel extends PreTrainedModel { }
|
|
4423
|
+
export class DepthProForDepthEstimation extends DepthProPreTrainedModel { }
|
|
4424
|
+
//////////////////////////////////////////////////
|
|
4425
|
+
|
|
4426
|
+
//////////////////////////////////////////////////
|
|
4427
|
+
export class MaskFormerPreTrainedModel extends PreTrainedModel { }
|
|
4428
|
+
export class MaskFormerModel extends MaskFormerPreTrainedModel { }
|
|
4429
|
+
export class MaskFormerForInstanceSegmentation extends MaskFormerPreTrainedModel { }
|
|
4430
|
+
//////////////////////////////////////////////////
|
|
4431
|
+
|
|
4571
4432
|
//////////////////////////////////////////////////
|
|
4572
4433
|
export class GLPNPreTrainedModel extends PreTrainedModel { }
|
|
4573
4434
|
|
|
@@ -4944,19 +4805,7 @@ export class SamImageSegmentationOutput extends ModelOutput {
|
|
|
4944
4805
|
|
|
4945
4806
|
//////////////////////////////////////////////////
|
|
4946
4807
|
// MarianMT models
|
|
4947
|
-
export class MarianPreTrainedModel extends PreTrainedModel {
|
|
4948
|
-
|
|
4949
|
-
/**
|
|
4950
|
-
* Creates a new instance of the `MarianMTModel` class.
|
|
4951
|
-
* @param {Object} config The model configuration.
|
|
4952
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4953
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4954
|
-
*/
|
|
4955
|
-
constructor(config, sessions, generation_config) {
|
|
4956
|
-
super(config, sessions);
|
|
4957
|
-
this.generation_config = generation_config;
|
|
4958
|
-
}
|
|
4959
|
-
};
|
|
4808
|
+
export class MarianPreTrainedModel extends PreTrainedModel { };
|
|
4960
4809
|
|
|
4961
4810
|
export class MarianModel extends MarianPreTrainedModel { }
|
|
4962
4811
|
|
|
@@ -4965,19 +4814,7 @@ export class MarianMTModel extends MarianPreTrainedModel { }
|
|
|
4965
4814
|
|
|
4966
4815
|
//////////////////////////////////////////////////
|
|
4967
4816
|
// M2M100 models
|
|
4968
|
-
export class M2M100PreTrainedModel extends PreTrainedModel {
|
|
4969
|
-
|
|
4970
|
-
/**
|
|
4971
|
-
* Creates a new instance of the `M2M100ForConditionalGeneration` class.
|
|
4972
|
-
* @param {Object} config The model configuration.
|
|
4973
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4974
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4975
|
-
*/
|
|
4976
|
-
constructor(config, sessions, generation_config) {
|
|
4977
|
-
super(config, sessions);
|
|
4978
|
-
this.generation_config = generation_config;
|
|
4979
|
-
}
|
|
4980
|
-
};
|
|
4817
|
+
export class M2M100PreTrainedModel extends PreTrainedModel { };
|
|
4981
4818
|
|
|
4982
4819
|
export class M2M100Model extends M2M100PreTrainedModel { }
|
|
4983
4820
|
|
|
@@ -5069,7 +4906,7 @@ export class PyAnnoteModel extends PyAnnotePreTrainedModel { }
|
|
|
5069
4906
|
* **Example:** Load and run a `PyAnnoteForAudioFrameClassification` for speaker diarization.
|
|
5070
4907
|
*
|
|
5071
4908
|
* ```javascript
|
|
5072
|
-
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@
|
|
4909
|
+
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';
|
|
5073
4910
|
*
|
|
5074
4911
|
* // Load model and processor
|
|
5075
4912
|
* const model_id = 'onnx-community/pyannote-segmentation-3.0';
|
|
@@ -5487,19 +5324,7 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
|
|
|
5487
5324
|
/**
|
|
5488
5325
|
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
5489
5326
|
*/
|
|
5490
|
-
export class SpeechT5PreTrainedModel extends PreTrainedModel {
|
|
5491
|
-
|
|
5492
|
-
/**
|
|
5493
|
-
* Creates a new instance of the `SpeechT5ForTextToSpeech` class.
|
|
5494
|
-
* @param {Object} config The model configuration.
|
|
5495
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
5496
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5497
|
-
*/
|
|
5498
|
-
constructor(config, sessions, generation_config) {
|
|
5499
|
-
super(config, sessions);
|
|
5500
|
-
this.generation_config = generation_config;
|
|
5501
|
-
}
|
|
5502
|
-
};
|
|
5327
|
+
export class SpeechT5PreTrainedModel extends PreTrainedModel { };
|
|
5503
5328
|
|
|
5504
5329
|
/**
|
|
5505
5330
|
* The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
|
|
@@ -5660,18 +5485,7 @@ export class SpeechT5HifiGan extends PreTrainedModel {
|
|
|
5660
5485
|
|
|
5661
5486
|
//////////////////////////////////////////////////
|
|
5662
5487
|
// TrOCR models
|
|
5663
|
-
export class TrOCRPreTrainedModel extends PreTrainedModel {
|
|
5664
|
-
/**
|
|
5665
|
-
* Creates a new instance of the `TrOCRPreTrainedModel` class.
|
|
5666
|
-
* @param {Object} config The configuration of the model.
|
|
5667
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5668
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5669
|
-
*/
|
|
5670
|
-
constructor(config, session, generation_config) {
|
|
5671
|
-
super(config, session);
|
|
5672
|
-
this.generation_config = generation_config;
|
|
5673
|
-
}
|
|
5674
|
-
}
|
|
5488
|
+
export class TrOCRPreTrainedModel extends PreTrainedModel { }
|
|
5675
5489
|
|
|
5676
5490
|
/**
|
|
5677
5491
|
* The TrOCR Decoder with a language modeling head.
|
|
@@ -5686,18 +5500,7 @@ export class TrOCRForCausalLM extends TrOCRPreTrainedModel { }
|
|
|
5686
5500
|
/**
|
|
5687
5501
|
* The bare Mistral Model outputting raw hidden-states without any specific head on top.
|
|
5688
5502
|
*/
|
|
5689
|
-
export class MistralPreTrainedModel extends PreTrainedModel {
|
|
5690
|
-
/**
|
|
5691
|
-
* Creates a new instance of the `MistralPreTrainedModel` class.
|
|
5692
|
-
* @param {Object} config The configuration of the model.
|
|
5693
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5694
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5695
|
-
*/
|
|
5696
|
-
constructor(config, session, generation_config) {
|
|
5697
|
-
super(config, session);
|
|
5698
|
-
this.generation_config = generation_config;
|
|
5699
|
-
}
|
|
5700
|
-
}
|
|
5503
|
+
export class MistralPreTrainedModel extends PreTrainedModel { }
|
|
5701
5504
|
|
|
5702
5505
|
export class MistralModel extends MistralPreTrainedModel { }
|
|
5703
5506
|
|
|
@@ -5710,18 +5513,7 @@ export class MistralForCausalLM extends MistralPreTrainedModel { }
|
|
|
5710
5513
|
/**
|
|
5711
5514
|
* The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.
|
|
5712
5515
|
*/
|
|
5713
|
-
export class Starcoder2PreTrainedModel extends PreTrainedModel {
|
|
5714
|
-
/**
|
|
5715
|
-
* Creates a new instance of the `Starcoder2PreTrainedModel` class.
|
|
5716
|
-
* @param {Object} config The configuration of the model.
|
|
5717
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5718
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5719
|
-
*/
|
|
5720
|
-
constructor(config, session, generation_config) {
|
|
5721
|
-
super(config, session);
|
|
5722
|
-
this.generation_config = generation_config;
|
|
5723
|
-
}
|
|
5724
|
-
}
|
|
5516
|
+
export class Starcoder2PreTrainedModel extends PreTrainedModel { }
|
|
5725
5517
|
|
|
5726
5518
|
export class Starcoder2Model extends Starcoder2PreTrainedModel { }
|
|
5727
5519
|
|
|
@@ -5734,18 +5526,7 @@ export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel { }
|
|
|
5734
5526
|
/**
|
|
5735
5527
|
* The bare Falcon Model outputting raw hidden-states without any specific head on top.
|
|
5736
5528
|
*/
|
|
5737
|
-
export class FalconPreTrainedModel extends PreTrainedModel {
|
|
5738
|
-
/**
|
|
5739
|
-
* Creates a new instance of the `FalconPreTrainedModel` class.
|
|
5740
|
-
* @param {Object} config The configuration of the model.
|
|
5741
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5742
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5743
|
-
*/
|
|
5744
|
-
constructor(config, session, generation_config) {
|
|
5745
|
-
super(config, session);
|
|
5746
|
-
this.generation_config = generation_config;
|
|
5747
|
-
}
|
|
5748
|
-
}
|
|
5529
|
+
export class FalconPreTrainedModel extends PreTrainedModel { }
|
|
5749
5530
|
|
|
5750
5531
|
export class FalconModel extends FalconPreTrainedModel { }
|
|
5751
5532
|
|
|
@@ -5895,18 +5676,7 @@ export class SegformerForSemanticSegmentation extends SegformerPreTrainedModel {
|
|
|
5895
5676
|
|
|
5896
5677
|
//////////////////////////////////////////////////
|
|
5897
5678
|
// StableLm models
|
|
5898
|
-
export class StableLmPreTrainedModel extends PreTrainedModel {
|
|
5899
|
-
/**
|
|
5900
|
-
* Creates a new instance of the `StableLmPreTrainedModel` class.
|
|
5901
|
-
* @param {Object} config The configuration of the model.
|
|
5902
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5903
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5904
|
-
*/
|
|
5905
|
-
constructor(config, session, generation_config) {
|
|
5906
|
-
super(config, session);
|
|
5907
|
-
this.generation_config = generation_config;
|
|
5908
|
-
}
|
|
5909
|
-
}
|
|
5679
|
+
export class StableLmPreTrainedModel extends PreTrainedModel { }
|
|
5910
5680
|
|
|
5911
5681
|
/**
|
|
5912
5682
|
* The bare StableLm Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -6000,17 +5770,6 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
|
|
|
6000
5770
|
'past_key_values',
|
|
6001
5771
|
];
|
|
6002
5772
|
|
|
6003
|
-
/**
|
|
6004
|
-
* Creates a new instance of the `MusicgenForConditionalGeneration` class.
|
|
6005
|
-
* @param {Object} config The model configuration.
|
|
6006
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
6007
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
6008
|
-
*/
|
|
6009
|
-
constructor(config, sessions, generation_config) {
|
|
6010
|
-
super(config, sessions);
|
|
6011
|
-
this.generation_config = generation_config;
|
|
6012
|
-
}
|
|
6013
|
-
|
|
6014
5773
|
/**
|
|
6015
5774
|
* Apply the pattern mask to the final ids,
|
|
6016
5775
|
* then revert the pattern delay mask by filtering the pad token id in a single step.
|
|
@@ -6089,6 +5848,7 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
|
|
|
6089
5848
|
return audio_values;
|
|
6090
5849
|
}
|
|
6091
5850
|
}
|
|
5851
|
+
//////////////////////////////////////////////////
|
|
6092
5852
|
|
|
6093
5853
|
//////////////////////////////////////////////////
|
|
6094
5854
|
// MobileNetV1 models
|
|
@@ -6182,6 +5942,17 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
|
|
|
6182
5942
|
}
|
|
6183
5943
|
//////////////////////////////////////////////////
|
|
6184
5944
|
|
|
5945
|
+
//////////////////////////////////////////////////
|
|
5946
|
+
// Decision Transformer models
|
|
5947
|
+
export class DecisionTransformerPreTrainedModel extends PreTrainedModel { }
|
|
5948
|
+
|
|
5949
|
+
/**
|
|
5950
|
+
* The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL setting.
|
|
5951
|
+
* Refer to the paper for more details: https://arxiv.org/abs/2106.01345
|
|
5952
|
+
*/
|
|
5953
|
+
export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel { }
|
|
5954
|
+
|
|
5955
|
+
//////////////////////////////////////////////////
|
|
6185
5956
|
|
|
6186
5957
|
//////////////////////////////////////////////////
|
|
6187
5958
|
// AutoModels, used to simplify construction of PreTrainedModels
|
|
@@ -6220,7 +5991,7 @@ export class PretrainedMixin {
|
|
|
6220
5991
|
session_options = {},
|
|
6221
5992
|
} = {}) {
|
|
6222
5993
|
|
|
6223
|
-
|
|
5994
|
+
const options = {
|
|
6224
5995
|
progress_callback,
|
|
6225
5996
|
config,
|
|
6226
5997
|
cache_dir,
|
|
@@ -6239,7 +6010,7 @@ export class PretrainedMixin {
|
|
|
6239
6010
|
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
|
|
6240
6011
|
}
|
|
6241
6012
|
|
|
6242
|
-
for (
|
|
6013
|
+
for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
|
|
6243
6014
|
const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
|
|
6244
6015
|
if (!modelInfo) {
|
|
6245
6016
|
continue; // Item not found in this mapping
|
|
@@ -6294,6 +6065,10 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6294
6065
|
['rt_detr', ['RTDetrModel', RTDetrModel]],
|
|
6295
6066
|
['table-transformer', ['TableTransformerModel', TableTransformerModel]],
|
|
6296
6067
|
['vit', ['ViTModel', ViTModel]],
|
|
6068
|
+
['pvt', ['PvtModel', PvtModel]],
|
|
6069
|
+
['vit_msn', ['ViTMSNModel', ViTMSNModel]],
|
|
6070
|
+
['vit_mae', ['ViTMAEModel', ViTMAEModel]],
|
|
6071
|
+
['groupvit', ['GroupViTModel', GroupViTModel]],
|
|
6297
6072
|
['fastvit', ['FastViTModel', FastViTModel]],
|
|
6298
6073
|
['mobilevit', ['MobileViTModel', MobileViTModel]],
|
|
6299
6074
|
['mobilevitv2', ['MobileViTV2Model', MobileViTV2Model]],
|
|
@@ -6301,6 +6076,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6301
6076
|
['owlv2', ['Owlv2Model', Owlv2Model]],
|
|
6302
6077
|
['beit', ['BeitModel', BeitModel]],
|
|
6303
6078
|
['deit', ['DeiTModel', DeiTModel]],
|
|
6079
|
+
['hiera', ['HieraModel', HieraModel]],
|
|
6304
6080
|
['convnext', ['ConvNextModel', ConvNextModel]],
|
|
6305
6081
|
['convnextv2', ['ConvNextV2Model', ConvNextV2Model]],
|
|
6306
6082
|
['dinov2', ['Dinov2Model', Dinov2Model]],
|
|
@@ -6315,10 +6091,14 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6315
6091
|
['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
|
|
6316
6092
|
['efficientnet', ['EfficientNetModel', EfficientNetModel]],
|
|
6317
6093
|
|
|
6094
|
+
['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
|
|
6095
|
+
|
|
6318
6096
|
['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
|
|
6319
6097
|
['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
|
|
6320
6098
|
['mobilenet_v3', ['MobileNetV3Model', MobileNetV3Model]],
|
|
6321
6099
|
['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
|
|
6100
|
+
|
|
6101
|
+
['maskformer', ['MaskFormerModel', MaskFormerModel]],
|
|
6322
6102
|
]);
|
|
6323
6103
|
|
|
6324
6104
|
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
|
|
@@ -6337,6 +6117,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
|
|
|
6337
6117
|
|
|
6338
6118
|
const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
6339
6119
|
['bloom', ['BloomModel', BloomModel]],
|
|
6120
|
+
['jais', ['JAISModel', JAISModel]],
|
|
6340
6121
|
['gpt2', ['GPT2Model', GPT2Model]],
|
|
6341
6122
|
['gptj', ['GPTJModel', GPTJModel]],
|
|
6342
6123
|
['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]],
|
|
@@ -6344,6 +6125,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
|
6344
6125
|
['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
|
|
6345
6126
|
['codegen', ['CodeGenModel', CodeGenModel]],
|
|
6346
6127
|
['llama', ['LlamaModel', LlamaModel]],
|
|
6128
|
+
['granite', ['GraniteModel', GraniteModel]],
|
|
6347
6129
|
['cohere', ['CohereModel', CohereModel]],
|
|
6348
6130
|
['gemma', ['GemmaModel', GemmaModel]],
|
|
6349
6131
|
['gemma2', ['Gemma2Model', Gemma2Model]],
|
|
@@ -6425,12 +6207,14 @@ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
|
|
|
6425
6207
|
const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
|
|
6426
6208
|
['bloom', ['BloomForCausalLM', BloomForCausalLM]],
|
|
6427
6209
|
['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]],
|
|
6210
|
+
['jais', ['JAISLMHeadModel', JAISLMHeadModel]],
|
|
6428
6211
|
['gptj', ['GPTJForCausalLM', GPTJForCausalLM]],
|
|
6429
6212
|
['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]],
|
|
6430
6213
|
['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]],
|
|
6431
6214
|
['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
|
|
6432
6215
|
['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
|
|
6433
6216
|
['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
|
|
6217
|
+
['granite', ['GraniteForCausalLM', GraniteForCausalLM]],
|
|
6434
6218
|
['cohere', ['CohereForCausalLM', CohereForCausalLM]],
|
|
6435
6219
|
['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
|
|
6436
6220
|
['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
|
|
@@ -6501,11 +6285,14 @@ const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
|
|
|
6501
6285
|
|
|
6502
6286
|
const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
|
|
6503
6287
|
['vit', ['ViTForImageClassification', ViTForImageClassification]],
|
|
6288
|
+
['pvt', ['PvtForImageClassification', PvtForImageClassification]],
|
|
6289
|
+
['vit_msn', ['ViTMSNForImageClassification', ViTMSNForImageClassification]],
|
|
6504
6290
|
['fastvit', ['FastViTForImageClassification', FastViTForImageClassification]],
|
|
6505
6291
|
['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]],
|
|
6506
6292
|
['mobilevitv2', ['MobileViTV2ForImageClassification', MobileViTV2ForImageClassification]],
|
|
6507
6293
|
['beit', ['BeitForImageClassification', BeitForImageClassification]],
|
|
6508
6294
|
['deit', ['DeiTForImageClassification', DeiTForImageClassification]],
|
|
6295
|
+
['hiera', ['HieraForImageClassification', HieraForImageClassification]],
|
|
6509
6296
|
['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]],
|
|
6510
6297
|
['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]],
|
|
6511
6298
|
['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]],
|
|
@@ -6532,12 +6319,19 @@ const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
|
|
|
6532
6319
|
]);
|
|
6533
6320
|
|
|
6534
6321
|
const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6322
|
+
// TODO: Do not add new models here
|
|
6535
6323
|
['detr', ['DetrForSegmentation', DetrForSegmentation]],
|
|
6536
6324
|
['clipseg', ['CLIPSegForImageSegmentation', CLIPSegForImageSegmentation]],
|
|
6537
6325
|
]);
|
|
6538
6326
|
|
|
6539
6327
|
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6540
6328
|
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
|
|
6329
|
+
['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
|
|
6330
|
+
]);
|
|
6331
|
+
|
|
6332
|
+
const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6333
|
+
['detr', ['DetrForSegmentation', DetrForSegmentation]],
|
|
6334
|
+
['maskformer', ['MaskFormerForInstanceSegmentation', MaskFormerForInstanceSegmentation]],
|
|
6541
6335
|
]);
|
|
6542
6336
|
|
|
6543
6337
|
const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
|
|
@@ -6586,6 +6380,12 @@ const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
|
|
|
6586
6380
|
['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
|
|
6587
6381
|
['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]],
|
|
6588
6382
|
['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
|
|
6383
|
+
['sapiens', ['SapiensForDepthEstimation', SapiensForDepthEstimation]],
|
|
6384
|
+
['depth_pro', ['DepthProForDepthEstimation', DepthProForDepthEstimation]],
|
|
6385
|
+
])
|
|
6386
|
+
|
|
6387
|
+
const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
|
|
6388
|
+
['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
|
|
6589
6389
|
])
|
|
6590
6390
|
|
|
6591
6391
|
// NOTE: This is custom to Transformers.js, and is necessary because certain models
|
|
@@ -6610,10 +6410,12 @@ const MODEL_CLASS_TYPE_MAPPING = [
|
|
|
6610
6410
|
[MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText],
|
|
6611
6411
|
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6612
6412
|
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6413
|
+
[MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6613
6414
|
[MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6614
6415
|
[MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6615
6416
|
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6616
6417
|
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6418
|
+
[MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6617
6419
|
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6618
6420
|
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6619
6421
|
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
|
|
@@ -6811,6 +6613,17 @@ export class AutoModelForSemanticSegmentation extends PretrainedMixin {
|
|
|
6811
6613
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES];
|
|
6812
6614
|
}
|
|
6813
6615
|
|
|
6616
|
+
/**
|
|
6617
|
+
* Helper class which is used to instantiate pretrained universal image segmentation models with the `from_pretrained` function.
|
|
6618
|
+
* The chosen model class is determined by the type specified in the model config.
|
|
6619
|
+
*
|
|
6620
|
+
* @example
|
|
6621
|
+
* let model = await AutoModelForUniversalSegmentation.from_pretrained('hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation');
|
|
6622
|
+
*/
|
|
6623
|
+
export class AutoModelForUniversalSegmentation extends PretrainedMixin {
|
|
6624
|
+
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES];
|
|
6625
|
+
}
|
|
6626
|
+
|
|
6814
6627
|
/**
|
|
6815
6628
|
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
|
|
6816
6629
|
* The chosen model class is determined by the type specified in the model config.
|
|
@@ -6870,6 +6683,10 @@ export class AutoModelForDepthEstimation extends PretrainedMixin {
|
|
|
6870
6683
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES];
|
|
6871
6684
|
}
|
|
6872
6685
|
|
|
6686
|
+
export class AutoModelForNormalEstimation extends PretrainedMixin {
|
|
6687
|
+
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
|
|
6688
|
+
}
|
|
6689
|
+
|
|
6873
6690
|
export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
|
|
6874
6691
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
|
|
6875
6692
|
}
|