@huggingface/transformers 3.0.0-alpha.9 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +33 -22
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +2515 -2525
- package/dist/transformers.cjs.map +1 -1
- package/dist/transformers.js +3529 -3455
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.cjs +25 -25
- package/dist/transformers.min.cjs.map +1 -1
- package/dist/transformers.min.js +39 -40
- package/dist/transformers.min.js.map +1 -1
- package/dist/transformers.min.mjs +56 -57
- package/dist/transformers.min.mjs.map +1 -1
- package/dist/transformers.mjs +2551 -2538
- package/dist/transformers.mjs.map +1 -1
- package/package.json +14 -13
- package/src/backends/onnx.js +24 -19
- package/src/configs.js +19 -4
- package/src/env.js +5 -9
- package/src/generation/logits_process.js +40 -37
- package/src/models.js +326 -514
- package/src/ops/registry.js +14 -3
- package/src/pipelines.js +5 -4
- package/src/processors.js +390 -351
- package/src/tokenizers.js +140 -175
- package/src/utils/constants.js +1 -1
- package/src/utils/core.js +12 -0
- package/src/utils/data-structures.js +13 -11
- package/src/utils/hub.js +1 -1
- package/src/utils/maths.js +14 -5
- package/src/utils/tensor.js +60 -13
- package/types/backends/onnx.d.ts +5 -2
- package/types/backends/onnx.d.ts.map +1 -1
- package/types/configs.d.ts +29 -3
- package/types/configs.d.ts.map +1 -1
- package/types/env.d.ts +4 -2
- package/types/env.d.ts.map +1 -1
- package/types/generation/logits_process.d.ts.map +1 -1
- package/types/models.d.ts +116 -289
- package/types/models.d.ts.map +1 -1
- package/types/ops/registry.d.ts +6 -6
- package/types/ops/registry.d.ts.map +1 -1
- package/types/pipelines.d.ts +1 -2
- package/types/pipelines.d.ts.map +1 -1
- package/types/processors.d.ts +55 -51
- package/types/processors.d.ts.map +1 -1
- package/types/tokenizers.d.ts +23 -32
- package/types/tokenizers.d.ts.map +1 -1
- package/types/utils/constants.d.ts +1 -1
- package/types/utils/constants.d.ts.map +1 -1
- package/types/utils/core.d.ts +7 -0
- package/types/utils/core.d.ts.map +1 -1
- package/types/utils/data-structures.d.ts +6 -6
- package/types/utils/data-structures.d.ts.map +1 -1
- package/types/utils/hub.d.ts +1 -1
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/maths.d.ts +2 -2
- package/types/utils/maths.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +27 -1
- package/types/utils/tensor.d.ts.map +1 -1
package/src/models.js
CHANGED
|
@@ -71,6 +71,10 @@ import {
|
|
|
71
71
|
getModelJSON,
|
|
72
72
|
} from './utils/hub.js';
|
|
73
73
|
|
|
74
|
+
import {
|
|
75
|
+
GITHUB_ISSUE_URL,
|
|
76
|
+
} from './utils/constants.js';
|
|
77
|
+
|
|
74
78
|
import {
|
|
75
79
|
LogitsProcessorList,
|
|
76
80
|
ForcedBOSTokenLogitsProcessor,
|
|
@@ -142,11 +146,12 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
|
|
|
142
146
|
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
|
|
143
147
|
* @param {string} fileName The name of the model file.
|
|
144
148
|
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
|
|
145
|
-
* @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
|
|
149
|
+
* @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
|
|
146
150
|
* @private
|
|
147
151
|
*/
|
|
148
152
|
async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
149
|
-
|
|
153
|
+
const custom_config = options.config?.['transformers.js_config'] ?? {};
|
|
154
|
+
let device = options.device ?? custom_config.device;
|
|
150
155
|
if (device && typeof device !== 'string') {
|
|
151
156
|
if (device.hasOwnProperty(fileName)) {
|
|
152
157
|
device = device[fileName];
|
|
@@ -164,7 +169,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
164
169
|
|
|
165
170
|
// If options.dtype is specified, we use it to choose the suffix for the model file.
|
|
166
171
|
// Otherwise, we use the default dtype for the device.
|
|
167
|
-
let dtype = options.dtype;
|
|
172
|
+
let dtype = options.dtype ?? custom_config.dtype;
|
|
168
173
|
if (typeof dtype !== 'string') {
|
|
169
174
|
if (dtype && dtype.hasOwnProperty(fileName)) {
|
|
170
175
|
dtype = dtype[fileName];
|
|
@@ -182,27 +187,54 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
182
187
|
throw new Error(`The device (${selectedDevice}) does not support fp16.`);
|
|
183
188
|
}
|
|
184
189
|
|
|
190
|
+
// Only valid for models with a decoder
|
|
191
|
+
const kv_cache_dtype = custom_config.kv_cache_dtype
|
|
192
|
+
? (typeof custom_config.kv_cache_dtype === 'string'
|
|
193
|
+
? custom_config.kv_cache_dtype
|
|
194
|
+
: custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
|
|
195
|
+
: undefined;
|
|
196
|
+
|
|
197
|
+
if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
|
|
198
|
+
throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
const session_config = {
|
|
202
|
+
dtype: selectedDtype,
|
|
203
|
+
kv_cache_dtype,
|
|
204
|
+
}
|
|
205
|
+
|
|
185
206
|
// Construct the model file name
|
|
186
207
|
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
|
|
187
208
|
const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
|
|
188
209
|
|
|
189
|
-
const session_options = { ...options.session_options }
|
|
210
|
+
const session_options = { ...options.session_options };
|
|
190
211
|
|
|
191
212
|
// Overwrite `executionProviders` if not specified
|
|
192
213
|
session_options.executionProviders ??= executionProviders;
|
|
193
214
|
|
|
215
|
+
// Overwrite `freeDimensionOverrides` if specified in config and not set in session options
|
|
216
|
+
const free_dimension_overrides = custom_config.free_dimension_overrides;
|
|
217
|
+
if (free_dimension_overrides) {
|
|
218
|
+
session_options.freeDimensionOverrides ??= free_dimension_overrides;
|
|
219
|
+
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
|
|
220
|
+
console.warn(
|
|
221
|
+
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
|
|
222
|
+
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
|
|
223
|
+
);
|
|
224
|
+
}
|
|
194
225
|
|
|
195
226
|
const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
|
|
196
227
|
|
|
197
228
|
// handle onnx external data files
|
|
229
|
+
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
|
|
198
230
|
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
|
|
199
231
|
let externalDataPromises = [];
|
|
200
|
-
if (
|
|
201
|
-
|
|
232
|
+
if (use_external_data_format && (
|
|
233
|
+
use_external_data_format === true ||
|
|
202
234
|
(
|
|
203
|
-
typeof
|
|
204
|
-
|
|
205
|
-
|
|
235
|
+
typeof use_external_data_format === 'object' &&
|
|
236
|
+
use_external_data_format.hasOwnProperty(fileName) &&
|
|
237
|
+
use_external_data_format[fileName] === true
|
|
206
238
|
)
|
|
207
239
|
)) {
|
|
208
240
|
if (apis.IS_NODE_ENV) {
|
|
@@ -236,6 +268,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
236
268
|
});
|
|
237
269
|
if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
|
|
238
270
|
// Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
|
|
271
|
+
/** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
|
|
239
272
|
const preferredOutputLocation = {};
|
|
240
273
|
for (const key in shapes) {
|
|
241
274
|
preferredOutputLocation[key] = 'gpu-buffer';
|
|
@@ -245,7 +278,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
245
278
|
}
|
|
246
279
|
|
|
247
280
|
const buffer = await bufferPromise;
|
|
248
|
-
|
|
281
|
+
|
|
282
|
+
return { buffer, session_options, session_config };
|
|
249
283
|
}
|
|
250
284
|
|
|
251
285
|
/**
|
|
@@ -260,13 +294,30 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
260
294
|
async function constructSessions(pretrained_model_name_or_path, names, options) {
|
|
261
295
|
return Object.fromEntries(await Promise.all(
|
|
262
296
|
Object.keys(names).map(async (name) => {
|
|
263
|
-
const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
|
|
264
|
-
const session = await createInferenceSession(buffer, session_options);
|
|
297
|
+
const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
|
|
298
|
+
const session = await createInferenceSession(buffer, session_options, session_config);
|
|
265
299
|
return [name, session];
|
|
266
300
|
})
|
|
267
301
|
));
|
|
268
302
|
}
|
|
269
303
|
|
|
304
|
+
/**
|
|
305
|
+
* Helper function to load multiple optional configuration files
|
|
306
|
+
* @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
|
|
307
|
+
* @param {Record<string, string>} names The names of the config files to load.
|
|
308
|
+
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
|
|
309
|
+
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
|
|
310
|
+
* @private
|
|
311
|
+
*/
|
|
312
|
+
async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
|
|
313
|
+
return Object.fromEntries(await Promise.all(
|
|
314
|
+
Object.keys(names).map(async (name) => {
|
|
315
|
+
const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
|
|
316
|
+
return [name, config];
|
|
317
|
+
})
|
|
318
|
+
));
|
|
319
|
+
}
|
|
320
|
+
|
|
270
321
|
/**
|
|
271
322
|
* Validate model inputs
|
|
272
323
|
* @param {Object} session The InferenceSession object that will be run.
|
|
@@ -393,37 +444,6 @@ function toI64Tensor(items) {
|
|
|
393
444
|
}
|
|
394
445
|
}
|
|
395
446
|
|
|
396
|
-
/**
|
|
397
|
-
* Prepares an attention mask for a sequence of tokens based on configuration options.
|
|
398
|
-
* @param {Object} self The calling object instance.
|
|
399
|
-
* @param {Tensor} tokens The input tokens.
|
|
400
|
-
* @returns {Tensor} The attention mask tensor.
|
|
401
|
-
* @private
|
|
402
|
-
*/
|
|
403
|
-
function prepareAttentionMask(self, tokens) {
|
|
404
|
-
|
|
405
|
-
// Prepare attention mask
|
|
406
|
-
let pad_token_id = self.config.pad_token_id ?? null;
|
|
407
|
-
let eos_token_id = self.config.eos_token_id ?? null;
|
|
408
|
-
if (isIntegralNumber(eos_token_id)) {
|
|
409
|
-
eos_token_id = [eos_token_id];
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
let is_pad_token_in_inputs = tokens.indexOf(pad_token_id) !== -1;
|
|
413
|
-
let is_pad_token_not_equal_to_eos_token_id = (eos_token_id === null) || !eos_token_id.includes(pad_token_id)
|
|
414
|
-
|
|
415
|
-
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
|
|
416
|
-
let data = BigInt64Array.from(
|
|
417
|
-
// Note: != so that int matches bigint
|
|
418
|
-
// @ts-ignore
|
|
419
|
-
tokens.data.map(x => x != pad_token_id)
|
|
420
|
-
)
|
|
421
|
-
return new Tensor('int64', data, tokens.dims)
|
|
422
|
-
} else {
|
|
423
|
-
return ones_like(tokens);
|
|
424
|
-
}
|
|
425
|
-
}
|
|
426
|
-
|
|
427
447
|
/**
|
|
428
448
|
* Creates a boolean tensor with a single value.
|
|
429
449
|
* @param {boolean} value The value of the tensor.
|
|
@@ -694,8 +714,8 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
|
|
|
694
714
|
} else {
|
|
695
715
|
return decoder_prepare_inputs_for_generation(self, ...args);
|
|
696
716
|
}
|
|
697
|
-
|
|
698
717
|
}
|
|
718
|
+
|
|
699
719
|
//////////////////////////////////////////////////
|
|
700
720
|
|
|
701
721
|
//////////////////////////////////////////////////
|
|
@@ -709,12 +729,14 @@ export class PreTrainedModel extends Callable {
|
|
|
709
729
|
* Creates a new instance of the `PreTrainedModel` class.
|
|
710
730
|
* @param {import('./configs.js').PretrainedConfig} config The model configuration.
|
|
711
731
|
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
732
|
+
* @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
|
|
712
733
|
*/
|
|
713
|
-
constructor(config, sessions) {
|
|
734
|
+
constructor(config, sessions, configs) {
|
|
714
735
|
super();
|
|
715
736
|
|
|
716
737
|
this.config = config;
|
|
717
738
|
this.sessions = sessions;
|
|
739
|
+
this.configs = configs;
|
|
718
740
|
|
|
719
741
|
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
|
|
720
742
|
const modelType = MODEL_TYPE_MAPPING.get(modelName);
|
|
@@ -830,7 +852,9 @@ export class PreTrainedModel extends Callable {
|
|
|
830
852
|
constructSessions(pretrained_model_name_or_path, {
|
|
831
853
|
model: options.model_file_name ?? 'model',
|
|
832
854
|
}, options),
|
|
833
|
-
|
|
855
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
856
|
+
generation_config: 'generation_config.json',
|
|
857
|
+
}, options),
|
|
834
858
|
]);
|
|
835
859
|
|
|
836
860
|
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
|
|
@@ -839,7 +863,9 @@ export class PreTrainedModel extends Callable {
|
|
|
839
863
|
model: 'encoder_model',
|
|
840
864
|
decoder_model_merged: 'decoder_model_merged',
|
|
841
865
|
}, options),
|
|
842
|
-
|
|
866
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
867
|
+
generation_config: 'generation_config.json',
|
|
868
|
+
}, options),
|
|
843
869
|
]);
|
|
844
870
|
|
|
845
871
|
} else if (modelType === MODEL_TYPES.MaskGeneration) {
|
|
@@ -869,7 +895,9 @@ export class PreTrainedModel extends Callable {
|
|
|
869
895
|
}
|
|
870
896
|
info = await Promise.all([
|
|
871
897
|
constructSessions(pretrained_model_name_or_path, sessions, options),
|
|
872
|
-
|
|
898
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
899
|
+
generation_config: 'generation_config.json',
|
|
900
|
+
}, options),
|
|
873
901
|
]);
|
|
874
902
|
|
|
875
903
|
} else if (modelType === MODEL_TYPES.Musicgen) {
|
|
@@ -879,12 +907,14 @@ export class PreTrainedModel extends Callable {
|
|
|
879
907
|
decoder_model_merged: 'decoder_model_merged',
|
|
880
908
|
encodec_decode: 'encodec_decode',
|
|
881
909
|
}, options),
|
|
882
|
-
|
|
910
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
911
|
+
generation_config: 'generation_config.json',
|
|
912
|
+
}, options),
|
|
883
913
|
]);
|
|
884
914
|
|
|
885
915
|
} else { // should be MODEL_TYPES.EncoderOnly
|
|
886
916
|
if (modelType !== MODEL_TYPES.EncoderOnly) {
|
|
887
|
-
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at
|
|
917
|
+
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
|
|
888
918
|
}
|
|
889
919
|
info = await Promise.all([
|
|
890
920
|
constructSessions(pretrained_model_name_or_path, {
|
|
@@ -917,6 +947,14 @@ export class PreTrainedModel extends Callable {
|
|
|
917
947
|
return await this._forward(this, model_inputs);
|
|
918
948
|
}
|
|
919
949
|
|
|
950
|
+
/**
|
|
951
|
+
* Get the model's generation config, if it exists.
|
|
952
|
+
* @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
|
|
953
|
+
*/
|
|
954
|
+
get generation_config() {
|
|
955
|
+
return this.configs?.generation_config ?? null;
|
|
956
|
+
}
|
|
957
|
+
|
|
920
958
|
/**
|
|
921
959
|
* This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
|
|
922
960
|
* instances used for multinomial sampling.
|
|
@@ -1096,9 +1134,7 @@ export class PreTrainedModel extends Callable {
|
|
|
1096
1134
|
const gen_config = new cls(config);
|
|
1097
1135
|
|
|
1098
1136
|
// Apply model's generation config, if it exists
|
|
1099
|
-
|
|
1100
|
-
Object.assign(gen_config, this.generation_config);
|
|
1101
|
-
}
|
|
1137
|
+
Object.assign(gen_config, this.generation_config ?? {});
|
|
1102
1138
|
|
|
1103
1139
|
// Next, use any generation config specified by the user
|
|
1104
1140
|
// when calling `generate`
|
|
@@ -1458,13 +1494,12 @@ export class PreTrainedModel extends Callable {
|
|
|
1458
1494
|
// - GenerationMode.BEAM_SEARCH
|
|
1459
1495
|
// - GenerationMode.BEAM_SAMPLE
|
|
1460
1496
|
////////////////////////////////////////////////////
|
|
1461
|
-
let
|
|
1497
|
+
let outputs;
|
|
1462
1498
|
let attentions = {};
|
|
1463
1499
|
while (true) {
|
|
1464
1500
|
// prepare model inputs
|
|
1465
1501
|
model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config);
|
|
1466
|
-
|
|
1467
|
-
const outputs = await this.forward(model_inputs);
|
|
1502
|
+
outputs = await this.forward(model_inputs);
|
|
1468
1503
|
|
|
1469
1504
|
if (generation_config.output_attentions && generation_config.return_dict_in_generate) {
|
|
1470
1505
|
// Get attentions if they are present
|
|
@@ -1511,10 +1546,6 @@ export class PreTrainedModel extends Callable {
|
|
|
1511
1546
|
|
|
1512
1547
|
const stop = prepared_stopping_criteria(all_input_ids);
|
|
1513
1548
|
if (stop.every(x => x)) {
|
|
1514
|
-
if (generation_config.return_dict_in_generate) {
|
|
1515
|
-
// Get past key values without disposing buffers
|
|
1516
|
-
past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, false);
|
|
1517
|
-
}
|
|
1518
1549
|
break;
|
|
1519
1550
|
}
|
|
1520
1551
|
|
|
@@ -1527,6 +1558,9 @@ export class PreTrainedModel extends Callable {
|
|
|
1527
1558
|
streamer.end();
|
|
1528
1559
|
}
|
|
1529
1560
|
|
|
1561
|
+
// Retrieve and dispose all final past key values (including encoder attentions)
|
|
1562
|
+
const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true);
|
|
1563
|
+
|
|
1530
1564
|
// TODO: ensure all_input_ids is padded correctly...
|
|
1531
1565
|
const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]);
|
|
1532
1566
|
|
|
@@ -1540,6 +1574,12 @@ export class PreTrainedModel extends Callable {
|
|
|
1540
1574
|
// logits,
|
|
1541
1575
|
}
|
|
1542
1576
|
} else {
|
|
1577
|
+
// Dispose all remaining tensors
|
|
1578
|
+
for (const tensor of Object.values(outputs)) {
|
|
1579
|
+
if (tensor.location === 'gpu-buffer') {
|
|
1580
|
+
tensor.dispose();
|
|
1581
|
+
}
|
|
1582
|
+
}
|
|
1543
1583
|
return sequences;
|
|
1544
1584
|
}
|
|
1545
1585
|
}
|
|
@@ -1549,31 +1589,32 @@ export class PreTrainedModel extends Callable {
|
|
|
1549
1589
|
*
|
|
1550
1590
|
* @param {Object} decoderResults The decoder results object.
|
|
1551
1591
|
* @param {Object} pastKeyValues The previous past key values.
|
|
1552
|
-
* @param {boolean} [dispose=true] Whether to dispose of the old gpu buffer.
|
|
1553
1592
|
* @returns {Object} An object containing past key values.
|
|
1554
1593
|
*/
|
|
1555
|
-
getPastKeyValues(decoderResults, pastKeyValues,
|
|
1594
|
+
getPastKeyValues(decoderResults, pastKeyValues, disposeEncoderPKVs = false) {
|
|
1556
1595
|
const pkvs = Object.create(null);
|
|
1557
1596
|
|
|
1558
1597
|
for (const name in decoderResults) {
|
|
1559
1598
|
if (name.startsWith('present')) {
|
|
1560
1599
|
const newName = name.replace('present', 'past_key_values');
|
|
1561
|
-
|
|
1562
|
-
if (
|
|
1563
|
-
// Optimization introduced by optimum to reuse past key values.
|
|
1564
|
-
// outputs with the previous past key values.
|
|
1600
|
+
const is_encoder_pkv = name.includes('encoder');
|
|
1601
|
+
if (is_encoder_pkv && pastKeyValues) {
|
|
1602
|
+
// Optimization introduced by optimum to reuse past key values.
|
|
1603
|
+
// So, we just replace the constant outputs (`decoderResults[name]`) with the previous past key values.
|
|
1565
1604
|
// https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
|
|
1566
1605
|
pkvs[newName] = pastKeyValues[newName];
|
|
1567
|
-
} else {
|
|
1568
|
-
if (dispose && pastKeyValues) {
|
|
1569
|
-
// Free old gpu buffer
|
|
1570
|
-
const t = pastKeyValues[newName];
|
|
1571
|
-
if (t.location === 'gpu-buffer') {
|
|
1572
|
-
t.dispose();
|
|
1573
|
-
}
|
|
1574
|
-
}
|
|
1606
|
+
} else { // decoder or using first encoder PKVs
|
|
1575
1607
|
pkvs[newName] = decoderResults[name];
|
|
1576
1608
|
}
|
|
1609
|
+
|
|
1610
|
+
if (pastKeyValues && (!is_encoder_pkv || disposeEncoderPKVs)) {
|
|
1611
|
+
// - Always dispose decoder PKVs
|
|
1612
|
+
// - Only dispose encoder past key values when requested (after generation)
|
|
1613
|
+
const t = pastKeyValues[newName];
|
|
1614
|
+
if (t.location === 'gpu-buffer') {
|
|
1615
|
+
t.dispose();
|
|
1616
|
+
}
|
|
1617
|
+
}
|
|
1577
1618
|
}
|
|
1578
1619
|
}
|
|
1579
1620
|
return pkvs;
|
|
@@ -1611,9 +1652,8 @@ export class PreTrainedModel extends Callable {
|
|
|
1611
1652
|
if (pastKeyValues) {
|
|
1612
1653
|
Object.assign(decoderFeeds, pastKeyValues)
|
|
1613
1654
|
} else {
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
const dtype = this.custom_config.kv_cache_dtype ?? 'float32';
|
|
1655
|
+
const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
|
|
1656
|
+
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
|
|
1617
1657
|
const empty = (dtype === 'float16') ? new Uint16Array() : [];
|
|
1618
1658
|
|
|
1619
1659
|
const shapes = getKeyValueShapes(this.config);
|
|
@@ -2506,17 +2546,6 @@ export class T5PreTrainedModel extends PreTrainedModel {
|
|
|
2506
2546
|
'decoder_attention_mask',
|
|
2507
2547
|
'past_key_values',
|
|
2508
2548
|
];
|
|
2509
|
-
|
|
2510
|
-
/**
|
|
2511
|
-
* Creates a new instance of the `T5PreTrainedModel` class.
|
|
2512
|
-
* @param {Object} config The model configuration.
|
|
2513
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2514
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2515
|
-
*/
|
|
2516
|
-
constructor(config, sessions, generation_config) {
|
|
2517
|
-
super(config, sessions);
|
|
2518
|
-
this.generation_config = generation_config;
|
|
2519
|
-
}
|
|
2520
2549
|
};
|
|
2521
2550
|
|
|
2522
2551
|
export class T5Model extends T5PreTrainedModel { }
|
|
@@ -2533,18 +2562,7 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { }
|
|
|
2533
2562
|
/**
|
|
2534
2563
|
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
2535
2564
|
*/
|
|
2536
|
-
export class LongT5PreTrainedModel extends PreTrainedModel {
|
|
2537
|
-
/**
|
|
2538
|
-
* Creates a new instance of the `LongT5ForConditionalGeneration` class.
|
|
2539
|
-
* @param {Object} config The model configuration.
|
|
2540
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2541
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2542
|
-
*/
|
|
2543
|
-
constructor(config, sessions, generation_config) {
|
|
2544
|
-
super(config, sessions);
|
|
2545
|
-
this.generation_config = generation_config;
|
|
2546
|
-
}
|
|
2547
|
-
};
|
|
2565
|
+
export class LongT5PreTrainedModel extends PreTrainedModel { };
|
|
2548
2566
|
|
|
2549
2567
|
/**
|
|
2550
2568
|
* The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -2560,19 +2578,7 @@ export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { }
|
|
|
2560
2578
|
|
|
2561
2579
|
//////////////////////////////////////////////////
|
|
2562
2580
|
// MT5 models
|
|
2563
|
-
export class MT5PreTrainedModel extends PreTrainedModel {
|
|
2564
|
-
|
|
2565
|
-
/**
|
|
2566
|
-
* Creates a new instance of the `MT5ForConditionalGeneration` class.
|
|
2567
|
-
* @param {Object} config The model configuration.
|
|
2568
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2569
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2570
|
-
*/
|
|
2571
|
-
constructor(config, sessions, generation_config) {
|
|
2572
|
-
super(config, sessions);
|
|
2573
|
-
this.generation_config = generation_config;
|
|
2574
|
-
}
|
|
2575
|
-
};
|
|
2581
|
+
export class MT5PreTrainedModel extends PreTrainedModel { };
|
|
2576
2582
|
|
|
2577
2583
|
export class MT5Model extends MT5PreTrainedModel { }
|
|
2578
2584
|
|
|
@@ -2584,19 +2590,7 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { }
|
|
|
2584
2590
|
|
|
2585
2591
|
//////////////////////////////////////////////////
|
|
2586
2592
|
// Bart models
|
|
2587
|
-
export class BartPretrainedModel extends PreTrainedModel {
|
|
2588
|
-
|
|
2589
|
-
/**
|
|
2590
|
-
* Creates a new instance of the `BartForConditionalGeneration` class.
|
|
2591
|
-
* @param {Object} config The model configuration.
|
|
2592
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2593
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2594
|
-
*/
|
|
2595
|
-
constructor(config, sessions, generation_config) {
|
|
2596
|
-
super(config, sessions);
|
|
2597
|
-
this.generation_config = generation_config;
|
|
2598
|
-
}
|
|
2599
|
-
};
|
|
2593
|
+
export class BartPretrainedModel extends PreTrainedModel { };
|
|
2600
2594
|
|
|
2601
2595
|
/**
|
|
2602
2596
|
* The bare BART Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2627,19 +2621,7 @@ export class BartForSequenceClassification extends BartPretrainedModel {
|
|
|
2627
2621
|
|
|
2628
2622
|
//////////////////////////////////////////////////
|
|
2629
2623
|
// MBart models
|
|
2630
|
-
export class MBartPreTrainedModel extends PreTrainedModel {
|
|
2631
|
-
|
|
2632
|
-
/**
|
|
2633
|
-
* Creates a new instance of the `MBartForConditionalGeneration` class.
|
|
2634
|
-
* @param {Object} config The model configuration.
|
|
2635
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2636
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2637
|
-
*/
|
|
2638
|
-
constructor(config, sessions, generation_config) {
|
|
2639
|
-
super(config, sessions);
|
|
2640
|
-
this.generation_config = generation_config;
|
|
2641
|
-
}
|
|
2642
|
-
};
|
|
2624
|
+
export class MBartPreTrainedModel extends PreTrainedModel { };
|
|
2643
2625
|
|
|
2644
2626
|
/**
|
|
2645
2627
|
* The bare MBART Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2673,19 +2655,7 @@ export class MBartForCausalLM extends MBartPreTrainedModel { }
|
|
|
2673
2655
|
|
|
2674
2656
|
//////////////////////////////////////////////////
|
|
2675
2657
|
// Blenderbot models
|
|
2676
|
-
export class BlenderbotPreTrainedModel extends PreTrainedModel {
|
|
2677
|
-
|
|
2678
|
-
/**
|
|
2679
|
-
* Creates a new instance of the `BlenderbotForConditionalGeneration` class.
|
|
2680
|
-
* @param {Object} config The model configuration.
|
|
2681
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2682
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2683
|
-
*/
|
|
2684
|
-
constructor(config, sessions, generation_config) {
|
|
2685
|
-
super(config, sessions);
|
|
2686
|
-
this.generation_config = generation_config;
|
|
2687
|
-
}
|
|
2688
|
-
};
|
|
2658
|
+
export class BlenderbotPreTrainedModel extends PreTrainedModel { };
|
|
2689
2659
|
|
|
2690
2660
|
/**
|
|
2691
2661
|
* The bare Blenderbot Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2701,19 +2671,7 @@ export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedMode
|
|
|
2701
2671
|
|
|
2702
2672
|
//////////////////////////////////////////////////
|
|
2703
2673
|
// Blenderbot models
|
|
2704
|
-
export class BlenderbotSmallPreTrainedModel extends PreTrainedModel {
|
|
2705
|
-
|
|
2706
|
-
/**
|
|
2707
|
-
* Creates a new instance of the `BlenderbotForConditionalGeneration` class.
|
|
2708
|
-
* @param {Object} config The model configuration.
|
|
2709
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2710
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2711
|
-
*/
|
|
2712
|
-
constructor(config, sessions, generation_config) {
|
|
2713
|
-
super(config, sessions);
|
|
2714
|
-
this.generation_config = generation_config;
|
|
2715
|
-
}
|
|
2716
|
-
};
|
|
2674
|
+
export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { };
|
|
2717
2675
|
|
|
2718
2676
|
/**
|
|
2719
2677
|
* The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top.
|
|
@@ -2962,17 +2920,6 @@ export class WhisperPreTrainedModel extends PreTrainedModel {
|
|
|
2962
2920
|
'decoder_attention_mask',
|
|
2963
2921
|
'past_key_values',
|
|
2964
2922
|
];
|
|
2965
|
-
|
|
2966
|
-
/**
|
|
2967
|
-
* Creates a new instance of the `WhisperPreTrainedModel` class.
|
|
2968
|
-
* @param {Object} config The model configuration.
|
|
2969
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
2970
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
2971
|
-
*/
|
|
2972
|
-
constructor(config, sessions, generation_config) {
|
|
2973
|
-
super(config, sessions);
|
|
2974
|
-
this.generation_config = generation_config;
|
|
2975
|
-
}
|
|
2976
2923
|
};
|
|
2977
2924
|
|
|
2978
2925
|
/**
|
|
@@ -3243,16 +3190,6 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
|
|
|
3243
3190
|
'encoder_hidden_states',
|
|
3244
3191
|
'past_key_values',
|
|
3245
3192
|
];
|
|
3246
|
-
/**
|
|
3247
|
-
* Creates a new instance of the `VisionEncoderDecoderModel` class.
|
|
3248
|
-
* @param {Object} config The model configuration.
|
|
3249
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3250
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3251
|
-
*/
|
|
3252
|
-
constructor(config, sessions, generation_config) {
|
|
3253
|
-
super(config, sessions);
|
|
3254
|
-
this.generation_config = generation_config;
|
|
3255
|
-
}
|
|
3256
3193
|
}
|
|
3257
3194
|
//////////////////////////////////////////////////
|
|
3258
3195
|
|
|
@@ -3267,11 +3204,6 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
|
|
|
3267
3204
|
'position_ids',
|
|
3268
3205
|
'past_key_values',
|
|
3269
3206
|
];
|
|
3270
|
-
|
|
3271
|
-
constructor(config, sessions, generation_config) {
|
|
3272
|
-
super(config, sessions);
|
|
3273
|
-
this.generation_config = generation_config;
|
|
3274
|
-
}
|
|
3275
3207
|
}
|
|
3276
3208
|
|
|
3277
3209
|
/**
|
|
@@ -3358,11 +3290,6 @@ export class Florence2PreTrainedModel extends PreTrainedModel {
|
|
|
3358
3290
|
'past_key_values',
|
|
3359
3291
|
];
|
|
3360
3292
|
main_input_name = 'inputs_embeds';
|
|
3361
|
-
|
|
3362
|
-
constructor(config, sessions, generation_config) {
|
|
3363
|
-
super(config, sessions);
|
|
3364
|
-
this.generation_config = generation_config;
|
|
3365
|
-
}
|
|
3366
3293
|
}
|
|
3367
3294
|
|
|
3368
3295
|
export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel {
|
|
@@ -3501,6 +3428,18 @@ export class CLIPPreTrainedModel extends PreTrainedModel { }
|
|
|
3501
3428
|
*/
|
|
3502
3429
|
export class CLIPModel extends CLIPPreTrainedModel { }
|
|
3503
3430
|
|
|
3431
|
+
/**
|
|
3432
|
+
* The text model from CLIP without any head or projection on top.
|
|
3433
|
+
*/
|
|
3434
|
+
export class CLIPTextModel extends CLIPPreTrainedModel {
|
|
3435
|
+
/** @type {PreTrainedModel.from_pretrained} */
|
|
3436
|
+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3437
|
+
// Update default model file name if not provided
|
|
3438
|
+
options.model_file_name ??= 'text_model';
|
|
3439
|
+
return super.from_pretrained(pretrained_model_name_or_path, options);
|
|
3440
|
+
}
|
|
3441
|
+
}
|
|
3442
|
+
|
|
3504
3443
|
/**
|
|
3505
3444
|
* CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output)
|
|
3506
3445
|
*
|
|
@@ -3528,7 +3467,6 @@ export class CLIPModel extends CLIPPreTrainedModel { }
|
|
|
3528
3467
|
* ```
|
|
3529
3468
|
*/
|
|
3530
3469
|
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
|
|
3531
|
-
|
|
3532
3470
|
/** @type {PreTrainedModel.from_pretrained} */
|
|
3533
3471
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3534
3472
|
// Update default model file name if not provided
|
|
@@ -3537,6 +3475,18 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
|
|
|
3537
3475
|
}
|
|
3538
3476
|
}
|
|
3539
3477
|
|
|
3478
|
+
/**
|
|
3479
|
+
* The vision model from CLIP without any head or projection on top.
|
|
3480
|
+
*/
|
|
3481
|
+
export class CLIPVisionModel extends CLIPPreTrainedModel {
|
|
3482
|
+
/** @type {PreTrainedModel.from_pretrained} */
|
|
3483
|
+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3484
|
+
// Update default model file name if not provided
|
|
3485
|
+
options.model_file_name ??= 'vision_model';
|
|
3486
|
+
return super.from_pretrained(pretrained_model_name_or_path, options);
|
|
3487
|
+
}
|
|
3488
|
+
}
|
|
3489
|
+
|
|
3540
3490
|
/**
|
|
3541
3491
|
* CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output)
|
|
3542
3492
|
*
|
|
@@ -3759,18 +3709,7 @@ export class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel { }
|
|
|
3759
3709
|
|
|
3760
3710
|
//////////////////////////////////////////////////
|
|
3761
3711
|
// GPT2 models
|
|
3762
|
-
export class GPT2PreTrainedModel extends PreTrainedModel {
|
|
3763
|
-
/**
|
|
3764
|
-
* Creates a new instance of the `GPT2PreTrainedModel` class.
|
|
3765
|
-
* @param {Object} config The model configuration.
|
|
3766
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3767
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3768
|
-
*/
|
|
3769
|
-
constructor(config, sessions, generation_config) {
|
|
3770
|
-
super(config, sessions);
|
|
3771
|
-
this.generation_config = generation_config;
|
|
3772
|
-
}
|
|
3773
|
-
}
|
|
3712
|
+
export class GPT2PreTrainedModel extends PreTrainedModel { }
|
|
3774
3713
|
|
|
3775
3714
|
export class GPT2Model extends GPT2PreTrainedModel { }
|
|
3776
3715
|
|
|
@@ -3783,20 +3722,25 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { }
|
|
|
3783
3722
|
// }
|
|
3784
3723
|
//////////////////////////////////////////////////
|
|
3785
3724
|
|
|
3725
|
+
//////////////////////////////////////////////////
|
|
3726
|
+
// JAIS models
|
|
3727
|
+
export class JAISPreTrainedModel extends PreTrainedModel { }
|
|
3728
|
+
|
|
3729
|
+
/**
|
|
3730
|
+
* The bare JAIS Model transformer outputting raw hidden-states without any specific head on top.
|
|
3731
|
+
*/
|
|
3732
|
+
export class JAISModel extends JAISPreTrainedModel { }
|
|
3733
|
+
|
|
3734
|
+
/**
|
|
3735
|
+
* The JAIS Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
|
3736
|
+
*/
|
|
3737
|
+
export class JAISLMHeadModel extends JAISPreTrainedModel { }
|
|
3738
|
+
//////////////////////////////////////////////////
|
|
3739
|
+
|
|
3740
|
+
|
|
3786
3741
|
//////////////////////////////////////////////////
|
|
3787
3742
|
// GPTNeo models
|
|
3788
|
-
export class GPTNeoPreTrainedModel extends PreTrainedModel {
|
|
3789
|
-
/**
|
|
3790
|
-
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
|
|
3791
|
-
* @param {Object} config The model configuration.
|
|
3792
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3793
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3794
|
-
*/
|
|
3795
|
-
constructor(config, sessions, generation_config) {
|
|
3796
|
-
super(config, sessions);
|
|
3797
|
-
this.generation_config = generation_config;
|
|
3798
|
-
}
|
|
3799
|
-
}
|
|
3743
|
+
export class GPTNeoPreTrainedModel extends PreTrainedModel { }
|
|
3800
3744
|
export class GPTNeoModel extends GPTNeoPreTrainedModel { }
|
|
3801
3745
|
|
|
3802
3746
|
export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
|
|
@@ -3804,18 +3748,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
|
|
|
3804
3748
|
|
|
3805
3749
|
//////////////////////////////////////////////////
|
|
3806
3750
|
// GPTNeoX models
|
|
3807
|
-
export class GPTNeoXPreTrainedModel extends PreTrainedModel {
|
|
3808
|
-
/**
|
|
3809
|
-
* Creates a new instance of the `GPTNeoXPreTrainedModel` class.
|
|
3810
|
-
* @param {Object} config The model configuration.
|
|
3811
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3812
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3813
|
-
*/
|
|
3814
|
-
constructor(config, sessions, generation_config) {
|
|
3815
|
-
super(config, sessions);
|
|
3816
|
-
this.generation_config = generation_config;
|
|
3817
|
-
}
|
|
3818
|
-
}
|
|
3751
|
+
export class GPTNeoXPreTrainedModel extends PreTrainedModel { }
|
|
3819
3752
|
export class GPTNeoXModel extends GPTNeoXPreTrainedModel { }
|
|
3820
3753
|
|
|
3821
3754
|
export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
|
|
@@ -3824,18 +3757,7 @@ export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
|
|
|
3824
3757
|
|
|
3825
3758
|
//////////////////////////////////////////////////
|
|
3826
3759
|
// GPT-J models
|
|
3827
|
-
export class GPTJPreTrainedModel extends PreTrainedModel {
|
|
3828
|
-
/**
|
|
3829
|
-
* Creates a new instance of the `GPTJPreTrainedModel` class.
|
|
3830
|
-
* @param {Object} config The model configuration.
|
|
3831
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3832
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3833
|
-
*/
|
|
3834
|
-
constructor(config, sessions, generation_config) {
|
|
3835
|
-
super(config, sessions);
|
|
3836
|
-
this.generation_config = generation_config;
|
|
3837
|
-
}
|
|
3838
|
-
}
|
|
3760
|
+
export class GPTJPreTrainedModel extends PreTrainedModel { }
|
|
3839
3761
|
|
|
3840
3762
|
export class GPTJModel extends GPTJPreTrainedModel { }
|
|
3841
3763
|
|
|
@@ -3845,18 +3767,7 @@ export class GPTJForCausalLM extends GPTJPreTrainedModel { }
|
|
|
3845
3767
|
|
|
3846
3768
|
//////////////////////////////////////////////////
|
|
3847
3769
|
// GPTBigCode models
|
|
3848
|
-
export class GPTBigCodePreTrainedModel extends PreTrainedModel {
|
|
3849
|
-
/**
|
|
3850
|
-
* Creates a new instance of the `GPTBigCodePreTrainedModel` class.
|
|
3851
|
-
* @param {Object} config The model configuration.
|
|
3852
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3853
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3854
|
-
*/
|
|
3855
|
-
constructor(config, sessions, generation_config) {
|
|
3856
|
-
super(config, sessions);
|
|
3857
|
-
this.generation_config = generation_config;
|
|
3858
|
-
}
|
|
3859
|
-
}
|
|
3770
|
+
export class GPTBigCodePreTrainedModel extends PreTrainedModel { }
|
|
3860
3771
|
|
|
3861
3772
|
export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { }
|
|
3862
3773
|
|
|
@@ -3865,18 +3776,7 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
|
|
|
3865
3776
|
|
|
3866
3777
|
//////////////////////////////////////////////////
|
|
3867
3778
|
// CodeGen models
|
|
3868
|
-
export class CodeGenPreTrainedModel extends PreTrainedModel {
|
|
3869
|
-
/**
|
|
3870
|
-
* Creates a new instance of the `CodeGenPreTrainedModel` class.
|
|
3871
|
-
* @param {Object} config The model configuration.
|
|
3872
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3873
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3874
|
-
*/
|
|
3875
|
-
constructor(config, sessions, generation_config) {
|
|
3876
|
-
super(config, sessions);
|
|
3877
|
-
this.generation_config = generation_config;
|
|
3878
|
-
}
|
|
3879
|
-
}
|
|
3779
|
+
export class CodeGenPreTrainedModel extends PreTrainedModel { }
|
|
3880
3780
|
/**
|
|
3881
3781
|
* CodeGenModel is a class representing a code generation model without a language model head.
|
|
3882
3782
|
*/
|
|
@@ -3895,18 +3795,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
|
|
|
3895
3795
|
/**
|
|
3896
3796
|
* The bare LLama Model outputting raw hidden-states without any specific head on top.
|
|
3897
3797
|
*/
|
|
3898
|
-
export class LlamaPreTrainedModel extends PreTrainedModel {
|
|
3899
|
-
/**
|
|
3900
|
-
* Creates a new instance of the `LlamaPreTrainedModel` class.
|
|
3901
|
-
* @param {Object} config The model configuration.
|
|
3902
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3903
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3904
|
-
*/
|
|
3905
|
-
constructor(config, sessions, generation_config) {
|
|
3906
|
-
super(config, sessions);
|
|
3907
|
-
this.generation_config = generation_config;
|
|
3908
|
-
}
|
|
3909
|
-
}
|
|
3798
|
+
export class LlamaPreTrainedModel extends PreTrainedModel { }
|
|
3910
3799
|
/**
|
|
3911
3800
|
* The bare LLaMA Model outputting raw hidden-states without any specific head on top.
|
|
3912
3801
|
*/
|
|
@@ -3915,24 +3804,22 @@ export class LlamaModel extends LlamaPreTrainedModel { }
|
|
|
3915
3804
|
export class LlamaForCausalLM extends LlamaPreTrainedModel { }
|
|
3916
3805
|
//////////////////////////////////////////////////
|
|
3917
3806
|
|
|
3807
|
+
|
|
3808
|
+
//////////////////////////////////////////////////
|
|
3809
|
+
// Granite models
|
|
3810
|
+
export class GranitePreTrainedModel extends PreTrainedModel { }
|
|
3811
|
+
export class GraniteModel extends GranitePreTrainedModel { }
|
|
3812
|
+
export class GraniteForCausalLM extends GranitePreTrainedModel { }
|
|
3813
|
+
//////////////////////////////////////////////////
|
|
3814
|
+
|
|
3815
|
+
|
|
3918
3816
|
//////////////////////////////////////////////////
|
|
3919
3817
|
// Cohere models
|
|
3920
3818
|
|
|
3921
3819
|
/**
|
|
3922
3820
|
* The bare Cohere Model outputting raw hidden-states without any specific head on top.
|
|
3923
3821
|
*/
|
|
3924
|
-
export class CoherePreTrainedModel extends PreTrainedModel {
|
|
3925
|
-
/**
|
|
3926
|
-
* Creates a new instance of the `CoherePreTrainedModel` class.
|
|
3927
|
-
* @param {Object} config The model configuration.
|
|
3928
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3929
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3930
|
-
*/
|
|
3931
|
-
constructor(config, sessions, generation_config) {
|
|
3932
|
-
super(config, sessions);
|
|
3933
|
-
this.generation_config = generation_config;
|
|
3934
|
-
}
|
|
3935
|
-
}
|
|
3822
|
+
export class CoherePreTrainedModel extends PreTrainedModel { }
|
|
3936
3823
|
export class CohereModel extends CoherePreTrainedModel { }
|
|
3937
3824
|
|
|
3938
3825
|
export class CohereForCausalLM extends CoherePreTrainedModel { }
|
|
@@ -3944,18 +3831,7 @@ export class CohereForCausalLM extends CoherePreTrainedModel { }
|
|
|
3944
3831
|
/**
|
|
3945
3832
|
* The bare Gemma Model outputting raw hidden-states without any specific head on top.
|
|
3946
3833
|
*/
|
|
3947
|
-
export class GemmaPreTrainedModel extends PreTrainedModel {
|
|
3948
|
-
/**
|
|
3949
|
-
* Creates a new instance of the `GemmaPreTrainedModel` class.
|
|
3950
|
-
* @param {Object} config The model configuration.
|
|
3951
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3952
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3953
|
-
*/
|
|
3954
|
-
constructor(config, sessions, generation_config) {
|
|
3955
|
-
super(config, sessions);
|
|
3956
|
-
this.generation_config = generation_config;
|
|
3957
|
-
}
|
|
3958
|
-
}
|
|
3834
|
+
export class GemmaPreTrainedModel extends PreTrainedModel { }
|
|
3959
3835
|
/**
|
|
3960
3836
|
* The bare Gemma Model outputting raw hidden-states without any specific head on top.
|
|
3961
3837
|
*/
|
|
@@ -3970,18 +3846,7 @@ export class GemmaForCausalLM extends GemmaPreTrainedModel { }
|
|
|
3970
3846
|
/**
|
|
3971
3847
|
* The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
|
|
3972
3848
|
*/
|
|
3973
|
-
export class Gemma2PreTrainedModel extends PreTrainedModel {
|
|
3974
|
-
/**
|
|
3975
|
-
* Creates a new instance of the `Gemma2PreTrainedModel` class.
|
|
3976
|
-
* @param {Object} config The model configuration.
|
|
3977
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3978
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
3979
|
-
*/
|
|
3980
|
-
constructor(config, sessions, generation_config) {
|
|
3981
|
-
super(config, sessions);
|
|
3982
|
-
this.generation_config = generation_config;
|
|
3983
|
-
}
|
|
3984
|
-
}
|
|
3849
|
+
export class Gemma2PreTrainedModel extends PreTrainedModel { }
|
|
3985
3850
|
/**
|
|
3986
3851
|
* The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
|
|
3987
3852
|
*/
|
|
@@ -3991,18 +3856,7 @@ export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
|
|
|
3991
3856
|
//////////////////////////////////////////////////
|
|
3992
3857
|
|
|
3993
3858
|
//////////////////////////////////////////////////
|
|
3994
|
-
export class OpenELMPreTrainedModel extends PreTrainedModel {
|
|
3995
|
-
/**
|
|
3996
|
-
* Creates a new instance of the `OpenELMPreTrainedModel` class.
|
|
3997
|
-
* @param {Object} config The model configuration.
|
|
3998
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
3999
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4000
|
-
*/
|
|
4001
|
-
constructor(config, sessions, generation_config) {
|
|
4002
|
-
super(config, sessions);
|
|
4003
|
-
this.generation_config = generation_config;
|
|
4004
|
-
}
|
|
4005
|
-
}
|
|
3859
|
+
export class OpenELMPreTrainedModel extends PreTrainedModel { }
|
|
4006
3860
|
export class OpenELMModel extends OpenELMPreTrainedModel { }
|
|
4007
3861
|
|
|
4008
3862
|
export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
|
|
@@ -4014,18 +3868,7 @@ export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
|
|
|
4014
3868
|
/**
|
|
4015
3869
|
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
|
|
4016
3870
|
*/
|
|
4017
|
-
export class Qwen2PreTrainedModel extends PreTrainedModel {
|
|
4018
|
-
/**
|
|
4019
|
-
* Creates a new instance of the `Qwen2PreTrainedModel` class.
|
|
4020
|
-
* @param {Object} config The model configuration.
|
|
4021
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4022
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4023
|
-
*/
|
|
4024
|
-
constructor(config, sessions, generation_config) {
|
|
4025
|
-
super(config, sessions);
|
|
4026
|
-
this.generation_config = generation_config;
|
|
4027
|
-
}
|
|
4028
|
-
}
|
|
3871
|
+
export class Qwen2PreTrainedModel extends PreTrainedModel { }
|
|
4029
3872
|
/**
|
|
4030
3873
|
* The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
|
|
4031
3874
|
*/
|
|
@@ -4037,18 +3880,7 @@ export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
|
|
|
4037
3880
|
|
|
4038
3881
|
//////////////////////////////////////////////////
|
|
4039
3882
|
// Phi models
|
|
4040
|
-
export class PhiPreTrainedModel extends PreTrainedModel {
|
|
4041
|
-
/**
|
|
4042
|
-
* Creates a new instance of the `PhiPreTrainedModel` class.
|
|
4043
|
-
* @param {Object} config The model configuration.
|
|
4044
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4045
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4046
|
-
*/
|
|
4047
|
-
constructor(config, sessions, generation_config) {
|
|
4048
|
-
super(config, sessions);
|
|
4049
|
-
this.generation_config = generation_config;
|
|
4050
|
-
}
|
|
4051
|
-
}
|
|
3883
|
+
export class PhiPreTrainedModel extends PreTrainedModel { }
|
|
4052
3884
|
/**
|
|
4053
3885
|
* The bare Phi Model outputting raw hidden-states without any specific head on top.
|
|
4054
3886
|
*/
|
|
@@ -4059,18 +3891,7 @@ export class PhiForCausalLM extends PhiPreTrainedModel { }
|
|
|
4059
3891
|
|
|
4060
3892
|
//////////////////////////////////////////////////
|
|
4061
3893
|
// Phi3 models
|
|
4062
|
-
export class Phi3PreTrainedModel extends PreTrainedModel {
|
|
4063
|
-
/**
|
|
4064
|
-
* Creates a new instance of the `Phi3PreTrainedModel` class.
|
|
4065
|
-
* @param {Object} config The model configuration.
|
|
4066
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4067
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4068
|
-
*/
|
|
4069
|
-
constructor(config, sessions, generation_config) {
|
|
4070
|
-
super(config, sessions);
|
|
4071
|
-
this.generation_config = generation_config;
|
|
4072
|
-
}
|
|
4073
|
-
}
|
|
3894
|
+
export class Phi3PreTrainedModel extends PreTrainedModel { }
|
|
4074
3895
|
|
|
4075
3896
|
/**
|
|
4076
3897
|
* The bare Phi3 Model outputting raw hidden-states without any specific head on top.
|
|
@@ -4086,18 +3907,7 @@ export class Phi3ForCausalLM extends Phi3PreTrainedModel { }
|
|
|
4086
3907
|
/**
|
|
4087
3908
|
* The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
|
4088
3909
|
*/
|
|
4089
|
-
export class BloomPreTrainedModel extends PreTrainedModel {
|
|
4090
|
-
/**
|
|
4091
|
-
* Creates a new instance of the `BloomPreTrainedModel` class.
|
|
4092
|
-
* @param {Object} config The model configuration.
|
|
4093
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4094
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4095
|
-
*/
|
|
4096
|
-
constructor(config, sessions, generation_config) {
|
|
4097
|
-
super(config, sessions);
|
|
4098
|
-
this.generation_config = generation_config;
|
|
4099
|
-
}
|
|
4100
|
-
}
|
|
3910
|
+
export class BloomPreTrainedModel extends PreTrainedModel { }
|
|
4101
3911
|
|
|
4102
3912
|
/**
|
|
4103
3913
|
* The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -4112,18 +3922,7 @@ export class BloomForCausalLM extends BloomPreTrainedModel { }
|
|
|
4112
3922
|
|
|
4113
3923
|
//////////////////////////////////////////////////
|
|
4114
3924
|
// MPT models
|
|
4115
|
-
export class MptPreTrainedModel extends PreTrainedModel {
|
|
4116
|
-
/**
|
|
4117
|
-
* Creates a new instance of the `MptPreTrainedModel` class.
|
|
4118
|
-
* @param {Object} config The model configuration.
|
|
4119
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4120
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4121
|
-
*/
|
|
4122
|
-
constructor(config, sessions, generation_config) {
|
|
4123
|
-
super(config, sessions);
|
|
4124
|
-
this.generation_config = generation_config;
|
|
4125
|
-
}
|
|
4126
|
-
}
|
|
3925
|
+
export class MptPreTrainedModel extends PreTrainedModel { }
|
|
4127
3926
|
|
|
4128
3927
|
/**
|
|
4129
3928
|
* The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -4139,18 +3938,7 @@ export class MptForCausalLM extends MptPreTrainedModel { }
|
|
|
4139
3938
|
|
|
4140
3939
|
//////////////////////////////////////////////////
|
|
4141
3940
|
// OPT models
|
|
4142
|
-
export class OPTPreTrainedModel extends PreTrainedModel {
|
|
4143
|
-
/**
|
|
4144
|
-
* Creates a new instance of the `OPTPreTrainedModel` class.
|
|
4145
|
-
* @param {Object} config The model configuration.
|
|
4146
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4147
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4148
|
-
*/
|
|
4149
|
-
constructor(config, sessions, generation_config) {
|
|
4150
|
-
super(config, sessions);
|
|
4151
|
-
this.generation_config = generation_config;
|
|
4152
|
-
}
|
|
4153
|
-
}
|
|
3941
|
+
export class OPTPreTrainedModel extends PreTrainedModel { }
|
|
4154
3942
|
|
|
4155
3943
|
/**
|
|
4156
3944
|
* The bare OPT Model outputting raw hidden-states without any specific head on top.
|
|
@@ -4176,6 +3964,43 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
|
|
|
4176
3964
|
}
|
|
4177
3965
|
//////////////////////////////////////////////////
|
|
4178
3966
|
|
|
3967
|
+
//////////////////////////////////////////////////
|
|
3968
|
+
export class PvtPreTrainedModel extends PreTrainedModel { }
|
|
3969
|
+
export class PvtModel extends PvtPreTrainedModel { }
|
|
3970
|
+
export class PvtForImageClassification extends PvtPreTrainedModel {
|
|
3971
|
+
/**
|
|
3972
|
+
* @param {any} model_inputs
|
|
3973
|
+
*/
|
|
3974
|
+
async _call(model_inputs) {
|
|
3975
|
+
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
3976
|
+
}
|
|
3977
|
+
}
|
|
3978
|
+
//////////////////////////////////////////////////
|
|
3979
|
+
|
|
3980
|
+
//////////////////////////////////////////////////
|
|
3981
|
+
export class ViTMAEPreTrainedModel extends PreTrainedModel { }
|
|
3982
|
+
export class ViTMAEModel extends ViTMAEPreTrainedModel { }
|
|
3983
|
+
//////////////////////////////////////////////////
|
|
3984
|
+
|
|
3985
|
+
|
|
3986
|
+
//////////////////////////////////////////////////
|
|
3987
|
+
export class ViTMSNPreTrainedModel extends PreTrainedModel { }
|
|
3988
|
+
export class ViTMSNModel extends ViTMSNPreTrainedModel { }
|
|
3989
|
+
export class ViTMSNForImageClassification extends ViTMSNPreTrainedModel {
|
|
3990
|
+
/**
|
|
3991
|
+
* @param {any} model_inputs
|
|
3992
|
+
*/
|
|
3993
|
+
async _call(model_inputs) {
|
|
3994
|
+
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
3995
|
+
}
|
|
3996
|
+
}
|
|
3997
|
+
//////////////////////////////////////////////////
|
|
3998
|
+
|
|
3999
|
+
//////////////////////////////////////////////////
|
|
4000
|
+
export class GroupViTPreTrainedModel extends PreTrainedModel { }
|
|
4001
|
+
export class GroupViTModel extends GroupViTPreTrainedModel { }
|
|
4002
|
+
//////////////////////////////////////////////////
|
|
4003
|
+
|
|
4179
4004
|
|
|
4180
4005
|
//////////////////////////////////////////////////
|
|
4181
4006
|
export class FastViTPreTrainedModel extends PreTrainedModel { }
|
|
@@ -4429,6 +4254,19 @@ export class DeiTForImageClassification extends DeiTPreTrainedModel {
|
|
|
4429
4254
|
}
|
|
4430
4255
|
//////////////////////////////////////////////////
|
|
4431
4256
|
|
|
4257
|
+
//////////////////////////////////////////////////
|
|
4258
|
+
export class HieraPreTrainedModel extends PreTrainedModel { }
|
|
4259
|
+
export class HieraModel extends HieraPreTrainedModel { }
|
|
4260
|
+
export class HieraForImageClassification extends HieraPreTrainedModel {
|
|
4261
|
+
/**
|
|
4262
|
+
* @param {any} model_inputs
|
|
4263
|
+
*/
|
|
4264
|
+
async _call(model_inputs) {
|
|
4265
|
+
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
4266
|
+
}
|
|
4267
|
+
}
|
|
4268
|
+
//////////////////////////////////////////////////
|
|
4269
|
+
|
|
4432
4270
|
|
|
4433
4271
|
//////////////////////////////////////////////////
|
|
4434
4272
|
/**
|
|
@@ -4568,6 +4406,24 @@ export class DepthAnythingForDepthEstimation extends DepthAnythingPreTrainedMode
|
|
|
4568
4406
|
//////////////////////////////////////////////////
|
|
4569
4407
|
|
|
4570
4408
|
|
|
4409
|
+
//////////////////////////////////////////////////
|
|
4410
|
+
export class SapiensPreTrainedModel extends PreTrainedModel { }
|
|
4411
|
+
export class SapiensForSemanticSegmentation extends SapiensPreTrainedModel { }
|
|
4412
|
+
export class SapiensForDepthEstimation extends SapiensPreTrainedModel { }
|
|
4413
|
+
export class SapiensForNormalEstimation extends SapiensPreTrainedModel { }
|
|
4414
|
+
//////////////////////////////////////////////////
|
|
4415
|
+
|
|
4416
|
+
//////////////////////////////////////////////////
|
|
4417
|
+
export class DepthProPreTrainedModel extends PreTrainedModel { }
|
|
4418
|
+
export class DepthProForDepthEstimation extends DepthProPreTrainedModel { }
|
|
4419
|
+
//////////////////////////////////////////////////
|
|
4420
|
+
|
|
4421
|
+
//////////////////////////////////////////////////
|
|
4422
|
+
export class MaskFormerPreTrainedModel extends PreTrainedModel { }
|
|
4423
|
+
export class MaskFormerModel extends MaskFormerPreTrainedModel { }
|
|
4424
|
+
export class MaskFormerForInstanceSegmentation extends MaskFormerPreTrainedModel { }
|
|
4425
|
+
//////////////////////////////////////////////////
|
|
4426
|
+
|
|
4571
4427
|
//////////////////////////////////////////////////
|
|
4572
4428
|
export class GLPNPreTrainedModel extends PreTrainedModel { }
|
|
4573
4429
|
|
|
@@ -4944,19 +4800,7 @@ export class SamImageSegmentationOutput extends ModelOutput {
|
|
|
4944
4800
|
|
|
4945
4801
|
//////////////////////////////////////////////////
|
|
4946
4802
|
// MarianMT models
|
|
4947
|
-
export class MarianPreTrainedModel extends PreTrainedModel {
|
|
4948
|
-
|
|
4949
|
-
/**
|
|
4950
|
-
* Creates a new instance of the `MarianMTModel` class.
|
|
4951
|
-
* @param {Object} config The model configuration.
|
|
4952
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4953
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4954
|
-
*/
|
|
4955
|
-
constructor(config, sessions, generation_config) {
|
|
4956
|
-
super(config, sessions);
|
|
4957
|
-
this.generation_config = generation_config;
|
|
4958
|
-
}
|
|
4959
|
-
};
|
|
4803
|
+
export class MarianPreTrainedModel extends PreTrainedModel { };
|
|
4960
4804
|
|
|
4961
4805
|
export class MarianModel extends MarianPreTrainedModel { }
|
|
4962
4806
|
|
|
@@ -4965,19 +4809,7 @@ export class MarianMTModel extends MarianPreTrainedModel { }
|
|
|
4965
4809
|
|
|
4966
4810
|
//////////////////////////////////////////////////
|
|
4967
4811
|
// M2M100 models
|
|
4968
|
-
export class M2M100PreTrainedModel extends PreTrainedModel {
|
|
4969
|
-
|
|
4970
|
-
/**
|
|
4971
|
-
* Creates a new instance of the `M2M100ForConditionalGeneration` class.
|
|
4972
|
-
* @param {Object} config The model configuration.
|
|
4973
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
4974
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
4975
|
-
*/
|
|
4976
|
-
constructor(config, sessions, generation_config) {
|
|
4977
|
-
super(config, sessions);
|
|
4978
|
-
this.generation_config = generation_config;
|
|
4979
|
-
}
|
|
4980
|
-
};
|
|
4812
|
+
export class M2M100PreTrainedModel extends PreTrainedModel { };
|
|
4981
4813
|
|
|
4982
4814
|
export class M2M100Model extends M2M100PreTrainedModel { }
|
|
4983
4815
|
|
|
@@ -5069,7 +4901,7 @@ export class PyAnnoteModel extends PyAnnotePreTrainedModel { }
|
|
|
5069
4901
|
* **Example:** Load and run a `PyAnnoteForAudioFrameClassification` for speaker diarization.
|
|
5070
4902
|
*
|
|
5071
4903
|
* ```javascript
|
|
5072
|
-
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@
|
|
4904
|
+
* import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';
|
|
5073
4905
|
*
|
|
5074
4906
|
* // Load model and processor
|
|
5075
4907
|
* const model_id = 'onnx-community/pyannote-segmentation-3.0';
|
|
@@ -5487,19 +5319,7 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
|
|
|
5487
5319
|
/**
|
|
5488
5320
|
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
|
5489
5321
|
*/
|
|
5490
|
-
export class SpeechT5PreTrainedModel extends PreTrainedModel {
|
|
5491
|
-
|
|
5492
|
-
/**
|
|
5493
|
-
* Creates a new instance of the `SpeechT5ForTextToSpeech` class.
|
|
5494
|
-
* @param {Object} config The model configuration.
|
|
5495
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
5496
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5497
|
-
*/
|
|
5498
|
-
constructor(config, sessions, generation_config) {
|
|
5499
|
-
super(config, sessions);
|
|
5500
|
-
this.generation_config = generation_config;
|
|
5501
|
-
}
|
|
5502
|
-
};
|
|
5322
|
+
export class SpeechT5PreTrainedModel extends PreTrainedModel { };
|
|
5503
5323
|
|
|
5504
5324
|
/**
|
|
5505
5325
|
* The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
|
|
@@ -5660,18 +5480,7 @@ export class SpeechT5HifiGan extends PreTrainedModel {
|
|
|
5660
5480
|
|
|
5661
5481
|
//////////////////////////////////////////////////
|
|
5662
5482
|
// TrOCR models
|
|
5663
|
-
export class TrOCRPreTrainedModel extends PreTrainedModel {
|
|
5664
|
-
/**
|
|
5665
|
-
* Creates a new instance of the `TrOCRPreTrainedModel` class.
|
|
5666
|
-
* @param {Object} config The configuration of the model.
|
|
5667
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5668
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5669
|
-
*/
|
|
5670
|
-
constructor(config, session, generation_config) {
|
|
5671
|
-
super(config, session);
|
|
5672
|
-
this.generation_config = generation_config;
|
|
5673
|
-
}
|
|
5674
|
-
}
|
|
5483
|
+
export class TrOCRPreTrainedModel extends PreTrainedModel { }
|
|
5675
5484
|
|
|
5676
5485
|
/**
|
|
5677
5486
|
* The TrOCR Decoder with a language modeling head.
|
|
@@ -5686,18 +5495,7 @@ export class TrOCRForCausalLM extends TrOCRPreTrainedModel { }
|
|
|
5686
5495
|
/**
|
|
5687
5496
|
* The bare Mistral Model outputting raw hidden-states without any specific head on top.
|
|
5688
5497
|
*/
|
|
5689
|
-
export class MistralPreTrainedModel extends PreTrainedModel {
|
|
5690
|
-
/**
|
|
5691
|
-
* Creates a new instance of the `MistralPreTrainedModel` class.
|
|
5692
|
-
* @param {Object} config The configuration of the model.
|
|
5693
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5694
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5695
|
-
*/
|
|
5696
|
-
constructor(config, session, generation_config) {
|
|
5697
|
-
super(config, session);
|
|
5698
|
-
this.generation_config = generation_config;
|
|
5699
|
-
}
|
|
5700
|
-
}
|
|
5498
|
+
export class MistralPreTrainedModel extends PreTrainedModel { }
|
|
5701
5499
|
|
|
5702
5500
|
export class MistralModel extends MistralPreTrainedModel { }
|
|
5703
5501
|
|
|
@@ -5710,18 +5508,7 @@ export class MistralForCausalLM extends MistralPreTrainedModel { }
|
|
|
5710
5508
|
/**
|
|
5711
5509
|
* The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.
|
|
5712
5510
|
*/
|
|
5713
|
-
export class Starcoder2PreTrainedModel extends PreTrainedModel {
|
|
5714
|
-
/**
|
|
5715
|
-
* Creates a new instance of the `Starcoder2PreTrainedModel` class.
|
|
5716
|
-
* @param {Object} config The configuration of the model.
|
|
5717
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5718
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5719
|
-
*/
|
|
5720
|
-
constructor(config, session, generation_config) {
|
|
5721
|
-
super(config, session);
|
|
5722
|
-
this.generation_config = generation_config;
|
|
5723
|
-
}
|
|
5724
|
-
}
|
|
5511
|
+
export class Starcoder2PreTrainedModel extends PreTrainedModel { }
|
|
5725
5512
|
|
|
5726
5513
|
export class Starcoder2Model extends Starcoder2PreTrainedModel { }
|
|
5727
5514
|
|
|
@@ -5734,18 +5521,7 @@ export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel { }
|
|
|
5734
5521
|
/**
|
|
5735
5522
|
* The bare Falcon Model outputting raw hidden-states without any specific head on top.
|
|
5736
5523
|
*/
|
|
5737
|
-
export class FalconPreTrainedModel extends PreTrainedModel {
|
|
5738
|
-
/**
|
|
5739
|
-
* Creates a new instance of the `FalconPreTrainedModel` class.
|
|
5740
|
-
* @param {Object} config The configuration of the model.
|
|
5741
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5742
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5743
|
-
*/
|
|
5744
|
-
constructor(config, session, generation_config) {
|
|
5745
|
-
super(config, session);
|
|
5746
|
-
this.generation_config = generation_config;
|
|
5747
|
-
}
|
|
5748
|
-
}
|
|
5524
|
+
export class FalconPreTrainedModel extends PreTrainedModel { }
|
|
5749
5525
|
|
|
5750
5526
|
export class FalconModel extends FalconPreTrainedModel { }
|
|
5751
5527
|
|
|
@@ -5895,18 +5671,7 @@ export class SegformerForSemanticSegmentation extends SegformerPreTrainedModel {
|
|
|
5895
5671
|
|
|
5896
5672
|
//////////////////////////////////////////////////
|
|
5897
5673
|
// StableLm models
|
|
5898
|
-
export class StableLmPreTrainedModel extends PreTrainedModel {
|
|
5899
|
-
/**
|
|
5900
|
-
* Creates a new instance of the `StableLmPreTrainedModel` class.
|
|
5901
|
-
* @param {Object} config The configuration of the model.
|
|
5902
|
-
* @param {any} session The ONNX session containing the model weights.
|
|
5903
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
5904
|
-
*/
|
|
5905
|
-
constructor(config, session, generation_config) {
|
|
5906
|
-
super(config, session);
|
|
5907
|
-
this.generation_config = generation_config;
|
|
5908
|
-
}
|
|
5909
|
-
}
|
|
5674
|
+
export class StableLmPreTrainedModel extends PreTrainedModel { }
|
|
5910
5675
|
|
|
5911
5676
|
/**
|
|
5912
5677
|
* The bare StableLm Model transformer outputting raw hidden-states without any specific head on top.
|
|
@@ -6000,17 +5765,6 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
|
|
|
6000
5765
|
'past_key_values',
|
|
6001
5766
|
];
|
|
6002
5767
|
|
|
6003
|
-
/**
|
|
6004
|
-
* Creates a new instance of the `MusicgenForConditionalGeneration` class.
|
|
6005
|
-
* @param {Object} config The model configuration.
|
|
6006
|
-
* @param {Record<string, any>} sessions The inference sessions for the model.
|
|
6007
|
-
* @param {GenerationConfig} generation_config The generation configuration.
|
|
6008
|
-
*/
|
|
6009
|
-
constructor(config, sessions, generation_config) {
|
|
6010
|
-
super(config, sessions);
|
|
6011
|
-
this.generation_config = generation_config;
|
|
6012
|
-
}
|
|
6013
|
-
|
|
6014
5768
|
/**
|
|
6015
5769
|
* Apply the pattern mask to the final ids,
|
|
6016
5770
|
* then revert the pattern delay mask by filtering the pad token id in a single step.
|
|
@@ -6089,6 +5843,7 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
|
|
|
6089
5843
|
return audio_values;
|
|
6090
5844
|
}
|
|
6091
5845
|
}
|
|
5846
|
+
//////////////////////////////////////////////////
|
|
6092
5847
|
|
|
6093
5848
|
//////////////////////////////////////////////////
|
|
6094
5849
|
// MobileNetV1 models
|
|
@@ -6182,6 +5937,17 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
|
|
|
6182
5937
|
}
|
|
6183
5938
|
//////////////////////////////////////////////////
|
|
6184
5939
|
|
|
5940
|
+
//////////////////////////////////////////////////
|
|
5941
|
+
// Decision Transformer models
|
|
5942
|
+
export class DecisionTransformerPreTrainedModel extends PreTrainedModel { }
|
|
5943
|
+
|
|
5944
|
+
/**
|
|
5945
|
+
* The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL setting.
|
|
5946
|
+
* Refer to the paper for more details: https://arxiv.org/abs/2106.01345
|
|
5947
|
+
*/
|
|
5948
|
+
export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel { }
|
|
5949
|
+
|
|
5950
|
+
//////////////////////////////////////////////////
|
|
6185
5951
|
|
|
6186
5952
|
//////////////////////////////////////////////////
|
|
6187
5953
|
// AutoModels, used to simplify construction of PreTrainedModels
|
|
@@ -6220,7 +5986,7 @@ export class PretrainedMixin {
|
|
|
6220
5986
|
session_options = {},
|
|
6221
5987
|
} = {}) {
|
|
6222
5988
|
|
|
6223
|
-
|
|
5989
|
+
const options = {
|
|
6224
5990
|
progress_callback,
|
|
6225
5991
|
config,
|
|
6226
5992
|
cache_dir,
|
|
@@ -6239,7 +6005,7 @@ export class PretrainedMixin {
|
|
|
6239
6005
|
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
|
|
6240
6006
|
}
|
|
6241
6007
|
|
|
6242
|
-
for (
|
|
6008
|
+
for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
|
|
6243
6009
|
const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
|
|
6244
6010
|
if (!modelInfo) {
|
|
6245
6011
|
continue; // Item not found in this mapping
|
|
@@ -6294,6 +6060,10 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6294
6060
|
['rt_detr', ['RTDetrModel', RTDetrModel]],
|
|
6295
6061
|
['table-transformer', ['TableTransformerModel', TableTransformerModel]],
|
|
6296
6062
|
['vit', ['ViTModel', ViTModel]],
|
|
6063
|
+
['pvt', ['PvtModel', PvtModel]],
|
|
6064
|
+
['vit_msn', ['ViTMSNModel', ViTMSNModel]],
|
|
6065
|
+
['vit_mae', ['ViTMAEModel', ViTMAEModel]],
|
|
6066
|
+
['groupvit', ['GroupViTModel', GroupViTModel]],
|
|
6297
6067
|
['fastvit', ['FastViTModel', FastViTModel]],
|
|
6298
6068
|
['mobilevit', ['MobileViTModel', MobileViTModel]],
|
|
6299
6069
|
['mobilevitv2', ['MobileViTV2Model', MobileViTV2Model]],
|
|
@@ -6301,6 +6071,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6301
6071
|
['owlv2', ['Owlv2Model', Owlv2Model]],
|
|
6302
6072
|
['beit', ['BeitModel', BeitModel]],
|
|
6303
6073
|
['deit', ['DeiTModel', DeiTModel]],
|
|
6074
|
+
['hiera', ['HieraModel', HieraModel]],
|
|
6304
6075
|
['convnext', ['ConvNextModel', ConvNextModel]],
|
|
6305
6076
|
['convnextv2', ['ConvNextV2Model', ConvNextV2Model]],
|
|
6306
6077
|
['dinov2', ['Dinov2Model', Dinov2Model]],
|
|
@@ -6315,10 +6086,14 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6315
6086
|
['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
|
|
6316
6087
|
['efficientnet', ['EfficientNetModel', EfficientNetModel]],
|
|
6317
6088
|
|
|
6089
|
+
['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
|
|
6090
|
+
|
|
6318
6091
|
['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
|
|
6319
6092
|
['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
|
|
6320
6093
|
['mobilenet_v3', ['MobileNetV3Model', MobileNetV3Model]],
|
|
6321
6094
|
['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
|
|
6095
|
+
|
|
6096
|
+
['maskformer', ['MaskFormerModel', MaskFormerModel]],
|
|
6322
6097
|
]);
|
|
6323
6098
|
|
|
6324
6099
|
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
|
|
@@ -6337,6 +6112,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
|
|
|
6337
6112
|
|
|
6338
6113
|
const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
6339
6114
|
['bloom', ['BloomModel', BloomModel]],
|
|
6115
|
+
['jais', ['JAISModel', JAISModel]],
|
|
6340
6116
|
['gpt2', ['GPT2Model', GPT2Model]],
|
|
6341
6117
|
['gptj', ['GPTJModel', GPTJModel]],
|
|
6342
6118
|
['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]],
|
|
@@ -6344,6 +6120,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
|
6344
6120
|
['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
|
|
6345
6121
|
['codegen', ['CodeGenModel', CodeGenModel]],
|
|
6346
6122
|
['llama', ['LlamaModel', LlamaModel]],
|
|
6123
|
+
['granite', ['GraniteModel', GraniteModel]],
|
|
6347
6124
|
['cohere', ['CohereModel', CohereModel]],
|
|
6348
6125
|
['gemma', ['GemmaModel', GemmaModel]],
|
|
6349
6126
|
['gemma2', ['Gemma2Model', Gemma2Model]],
|
|
@@ -6425,12 +6202,14 @@ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
|
|
|
6425
6202
|
const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
|
|
6426
6203
|
['bloom', ['BloomForCausalLM', BloomForCausalLM]],
|
|
6427
6204
|
['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]],
|
|
6205
|
+
['jais', ['JAISLMHeadModel', JAISLMHeadModel]],
|
|
6428
6206
|
['gptj', ['GPTJForCausalLM', GPTJForCausalLM]],
|
|
6429
6207
|
['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]],
|
|
6430
6208
|
['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]],
|
|
6431
6209
|
['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
|
|
6432
6210
|
['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
|
|
6433
6211
|
['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
|
|
6212
|
+
['granite', ['GraniteForCausalLM', GraniteForCausalLM]],
|
|
6434
6213
|
['cohere', ['CohereForCausalLM', CohereForCausalLM]],
|
|
6435
6214
|
['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
|
|
6436
6215
|
['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
|
|
@@ -6501,11 +6280,14 @@ const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
|
|
|
6501
6280
|
|
|
6502
6281
|
const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
|
|
6503
6282
|
['vit', ['ViTForImageClassification', ViTForImageClassification]],
|
|
6283
|
+
['pvt', ['PvtForImageClassification', PvtForImageClassification]],
|
|
6284
|
+
['vit_msn', ['ViTMSNForImageClassification', ViTMSNForImageClassification]],
|
|
6504
6285
|
['fastvit', ['FastViTForImageClassification', FastViTForImageClassification]],
|
|
6505
6286
|
['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]],
|
|
6506
6287
|
['mobilevitv2', ['MobileViTV2ForImageClassification', MobileViTV2ForImageClassification]],
|
|
6507
6288
|
['beit', ['BeitForImageClassification', BeitForImageClassification]],
|
|
6508
6289
|
['deit', ['DeiTForImageClassification', DeiTForImageClassification]],
|
|
6290
|
+
['hiera', ['HieraForImageClassification', HieraForImageClassification]],
|
|
6509
6291
|
['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]],
|
|
6510
6292
|
['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]],
|
|
6511
6293
|
['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]],
|
|
@@ -6532,12 +6314,19 @@ const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
|
|
|
6532
6314
|
]);
|
|
6533
6315
|
|
|
6534
6316
|
const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6317
|
+
// TODO: Do not add new models here
|
|
6535
6318
|
['detr', ['DetrForSegmentation', DetrForSegmentation]],
|
|
6536
6319
|
['clipseg', ['CLIPSegForImageSegmentation', CLIPSegForImageSegmentation]],
|
|
6537
6320
|
]);
|
|
6538
6321
|
|
|
6539
6322
|
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6540
6323
|
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
|
|
6324
|
+
['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
|
|
6325
|
+
]);
|
|
6326
|
+
|
|
6327
|
+
const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
6328
|
+
['detr', ['DetrForSegmentation', DetrForSegmentation]],
|
|
6329
|
+
['maskformer', ['MaskFormerForInstanceSegmentation', MaskFormerForInstanceSegmentation]],
|
|
6541
6330
|
]);
|
|
6542
6331
|
|
|
6543
6332
|
const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
|
|
@@ -6586,6 +6375,12 @@ const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
|
|
|
6586
6375
|
['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
|
|
6587
6376
|
['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]],
|
|
6588
6377
|
['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
|
|
6378
|
+
['sapiens', ['SapiensForDepthEstimation', SapiensForDepthEstimation]],
|
|
6379
|
+
['depth_pro', ['DepthProForDepthEstimation', DepthProForDepthEstimation]],
|
|
6380
|
+
])
|
|
6381
|
+
|
|
6382
|
+
const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
|
|
6383
|
+
['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
|
|
6589
6384
|
])
|
|
6590
6385
|
|
|
6591
6386
|
// NOTE: This is custom to Transformers.js, and is necessary because certain models
|
|
@@ -6610,10 +6405,12 @@ const MODEL_CLASS_TYPE_MAPPING = [
|
|
|
6610
6405
|
[MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText],
|
|
6611
6406
|
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6612
6407
|
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6408
|
+
[MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6613
6409
|
[MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6614
6410
|
[MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6615
6411
|
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6616
6412
|
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6413
|
+
[MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6617
6414
|
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6618
6415
|
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6619
6416
|
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
|
|
@@ -6811,6 +6608,17 @@ export class AutoModelForSemanticSegmentation extends PretrainedMixin {
|
|
|
6811
6608
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES];
|
|
6812
6609
|
}
|
|
6813
6610
|
|
|
6611
|
+
/**
|
|
6612
|
+
* Helper class which is used to instantiate pretrained universal image segmentation models with the `from_pretrained` function.
|
|
6613
|
+
* The chosen model class is determined by the type specified in the model config.
|
|
6614
|
+
*
|
|
6615
|
+
* @example
|
|
6616
|
+
* let model = await AutoModelForUniversalSegmentation.from_pretrained('hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation');
|
|
6617
|
+
*/
|
|
6618
|
+
export class AutoModelForUniversalSegmentation extends PretrainedMixin {
|
|
6619
|
+
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES];
|
|
6620
|
+
}
|
|
6621
|
+
|
|
6814
6622
|
/**
|
|
6815
6623
|
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
|
|
6816
6624
|
* The chosen model class is determined by the type specified in the model config.
|
|
@@ -6870,6 +6678,10 @@ export class AutoModelForDepthEstimation extends PretrainedMixin {
|
|
|
6870
6678
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES];
|
|
6871
6679
|
}
|
|
6872
6680
|
|
|
6681
|
+
export class AutoModelForNormalEstimation extends PretrainedMixin {
|
|
6682
|
+
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
|
|
6683
|
+
}
|
|
6684
|
+
|
|
6873
6685
|
export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
|
|
6874
6686
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
|
|
6875
6687
|
}
|