transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
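
For orientation before the per-file diffs: this release ports the core of the Hugging Face Transformers API to Ruby. A minimal quick-start sketch (hedged: `Transformers.pipeline` is defined in data/lib/transformers/pipelines/_init.rb above, the model id is a public Hugging Face Hub repo, and weights are downloaded on first use):

```ruby
require "transformers-rb"

# High-level pipeline API, mirroring the Python library
classifier = Transformers.pipeline("sentiment-analysis")
classifier.("This gem brings the Transformers API to Ruby.")
# Illustrative output: {label: "POSITIVE", score: 0.99...}
```

The diff bodies below cover the auto classes and the BERT configuration; the remaining files are listed above but not expanded here.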
data/lib/transformers/models/auto/auto_factory.rb
@@ -0,0 +1,138 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class BaseAutoModelClass
+    extend ClassAttribute
+
+    class_attribute :_model_mapping
+
+    class << self
+      def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = kwargs.delete(:config)
+        trust_remote_code = kwargs.delete(:trust_remote_code)
+        hub_kwargs_names = [
+          :cache_dir,
+          :force_download,
+          :local_files_only,
+          :proxies,
+          :resume_download,
+          :revision,
+          :subfolder,
+          :use_auth_token,
+          :token
+        ]
+        hub_kwargs = hub_kwargs_names.select { |k| kwargs.key?(k) }.to_h { |name| [name, kwargs.delete(name)] }
+        code_revision = kwargs.delete(:code_revision)
+        commit_hash = kwargs.delete(:_commit_hash)
+
+        if commit_hash.nil?
+          if !config.is_a?(PretrainedConfig)
+            # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+            resolved_config_file = Utils::Hub.cached_file(
+              pretrained_model_name_or_path,
+              CONFIG_NAME,
+              _raise_exceptions_for_gated_repo: false,
+              _raise_exceptions_for_missing_entries: false,
+              _raise_exceptions_for_connection_errors: false,
+              **hub_kwargs
+            )
+            commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+          else
+            raise Todo
+          end
+        end
+
+        if !config.is_a?(PretrainedConfig)
+          config, kwargs =
+            AutoConfig.from_pretrained(
+              pretrained_model_name_or_path,
+              return_unused_kwargs: true,
+              trust_remote_code: trust_remote_code,
+              code_revision: code_revision,
+              _commit_hash: commit_hash,
+              **hub_kwargs,
+              **kwargs
+            )
+        end
+
+        model_class = _get_model_class(config, _model_mapping)
+        model_class.from_pretrained(
+          pretrained_model_name_or_path, *model_args, config: config, **hub_kwargs, **kwargs
+        )
+      end
+
+      private
+
+      def _get_model_class(config, model_mapping)
+        supported_models = model_mapping[config.class.name.split("::").last]
+        if !supported_models.is_a?(Array)
+          return supported_models
+        end
+
+        raise Todo
+      end
+    end
+  end
+
+  class LazyAutoMapping
+    def initialize(config_mapping, model_mapping)
+      @config_mapping = config_mapping
+      @reverse_config_mapping = config_mapping.invert
+      @model_mapping = model_mapping
+      @modules = {}
+    end
+
+    def [](key)
+      model_type = @reverse_config_mapping[key]
+      if @model_mapping[model_type]
+        model_name = @model_mapping[model_type]
+        return _load_attr_from_module(model_type, model_name)
+      end
+
+      raise KeyError, key
+    end
+
+    def include?(key)
+      self[key]
+      true
+    rescue KeyError
+      false
+    end
+
+    private
+
+    def _load_attr_from_module(model_type, attr)
+      module_name = model_type_to_module_name(model_type)
+      if !@modules.include?(module_name)
+        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+      end
+      getattribute_from_module(@modules[module_name], attr)
+    end
+
+    def getattribute_from_module(mod, attr)
+      if attr.nil?
+        nil
+      elsif attr.is_a?(Array)
+        attr.map { |a| mod.const_get(a) }
+      else
+        mod.const_get(attr)
+      end
+    end
+
+    def model_type_to_module_name(key)
+      key
+    end
+  end
+end
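
The class names `LazyAutoMapping` resolves come from the mapping constants in configuration_auto.rb and modeling_auto.rb below. A minimal sketch of the lookup path, derived from the code above (direct use of `LazyAutoMapping` is for exposition only; it assumes the gem and its model modules are loaded):

```ruby
# model type => config class name, and model type => model class name
mapping = Transformers::LazyAutoMapping.new(
  {"bert" => "BertConfig"},
  {"bert" => "BertModel"}
)
mapping["BertConfig"]         # => Transformers::Bert::BertModel ("bert".capitalize selects the Bert module)
mapping.include?("ViTConfig") # => false (the KeyError raised by #[] is rescued)
```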
data/lib/transformers/models/auto/configuration_auto.rb
@@ -0,0 +1,61 @@
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  CONFIG_MAPPING_NAMES = {
+    "bert" => "BertConfig",
+    "distilbert" => "DistilBertConfig",
+    "vit" => "ViTConfig"
+  }
+
+  class LazyConfigMapping
+    def initialize(mapping)
+      @mapping = mapping
+      @extra_content = {}
+      @modules = {}
+    end
+
+    def [](key)
+      value = @mapping.fetch(key)
+      module_name = model_type_to_module_name(key)
+      if !@modules.include?(module_name)
+        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+      end
+      @modules[module_name].const_get(value)
+    end
+
+    def model_type_to_module_name(key)
+      key
+    end
+  end
+
+  CONFIG_MAPPING = LazyConfigMapping.new(CONFIG_MAPPING_NAMES)
+
+  class AutoConfig
+    def self.from_pretrained(pretrained_model_name_or_path, **kwargs)
+      kwargs[:_from_auto] = true
+      kwargs[:name_or_path] = pretrained_model_name_or_path
+      _trust_remote_code = kwargs.delete(:trust_remote_code)
+      _code_revision = kwargs.delete(:code_revision)
+
+      config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+      if config_dict[:model_type]
+        config_class = CONFIG_MAPPING[config_dict[:model_type]]
+        return config_class.from_dict(config_dict, **unused_kwargs)
+      else
+        raise Todo
+      end
+    end
+  end
+end
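
A hedged example of the config lookup (assumes network access; "bert-base-uncased" is a public Hub repo whose config.json declares `model_type` "bert"):

```ruby
config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
# model_type "bert" resolves through CONFIG_MAPPING to Transformers::Bert::BertConfig
config.hidden_size # => 768
```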
data/lib/transformers/models/auto/feature_extraction_auto.rb
@@ -0,0 +1,20 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  FEATURE_EXTRACTOR_MAPPING_NAMES = {
+  }
+
+  FEATURE_EXTRACTOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
+end
data/lib/transformers/models/auto/image_processing_auto.rb
@@ -0,0 +1,104 @@
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  IMAGE_PROCESSOR_MAPPING_NAMES = {
+    "vit" => ["ViTImageProcessor"]
+  }
+
+  IMAGE_PROCESSOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
+
+  class AutoImageProcessor
+    def self.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+      config = kwargs.delete(:config)
+      use_fast = kwargs.delete(:use_fast)
+      trust_remote_code = kwargs.delete(:trust_remote_code)
+      kwargs[:_from_auto] = true
+
+      config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+      image_processor_class = config_dict[:image_processor_type]
+      image_processor_auto_map = nil
+      if (config_dict[:auto_map] || {}).include?("AutoImageProcessor")
+        image_processor_auto_map = config_dict[:auto_map]["AutoImageProcessor"]
+      end
+
+      # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
+      # and if so, infer the image processor class from there.
+      if image_processor_class.nil? && image_processor_auto_map.nil?
+        feature_extractor_class = config_dict.delete(:feature_extractor_type)
+        if !feature_extractor_class.nil?
+          image_processor_class = feature_extractor_class.sub("FeatureExtractor", "ImageProcessor")
+        end
+        if (config_dict[:auto_map] || {}).include?("AutoFeatureExtractor")
+          feature_extractor_auto_map = config_dict[:auto_map]["AutoFeatureExtractor"]
+          image_processor_auto_map = feature_extractor_auto_map.sub("FeatureExtractor", "ImageProcessor")
+        end
+      end
+
+      # If we don't find the image processor class in the image processor config, let's try the model config.
+      if image_processor_class.nil? && image_processor_auto_map.nil?
+        if !config.is_a?(PretrainedConfig)
+          config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        end
+        # It could be in `config.image_processor_type``
+        image_processor_class = config.instance_variable_get(:@image_processor_type)
+      end
+
+      if !image_processor_class.nil?
+        raise Todo
+      end
+
+      has_remote_code = !image_processor_auto_map.nil?
+      has_local_code = !image_processor_class.nil? || IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
+      trust_remote_code = DynamicModuleUtils.resolve_trust_remote_code(
+        trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
+      )
+
+      if !image_processor_auto_map.nil? && !image_processor_auto_map.is_a?(Array)
+        raise Todo
+      end
+
+      if has_remote_code && trust_remote_code
+        raise Todo
+      elsif !image_processor_class.nil?
+        return image_processor_class.from_dict(config_dict, **kwargs)
+      # Last try: we use the IMAGE_PROCESSOR_MAPPING.
+      elsif IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
+        image_processor_tuple = IMAGE_PROCESSOR_MAPPING[config.class.name.split("::").last]
+
+        image_processor_class_py, image_processor_class_fast = image_processor_tuple
+
+        if !use_fast && !image_processor_class_fast.nil?
+          _warning_fast_image_processor_available(image_processor_class_fast)
+        end
+
+        if image_processor_class_fast && (use_fast || image_processor_class_py.nil?)
+          return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        else
+          if !image_processor_class_py.nil?
+            return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+          else
+            raise ArgumentError,
+              "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
+          end
+        end
+      end
+
+      raise ArgumentError,
+        "Unrecognized image processor in #{pretrained_model_name_or_path}. Should have a " +
+        "`image_processor_type` key in its #{IMAGE_PROCESSOR_NAME} of #{CONFIG_NAME}, or one of the following " +
+        "`model_type` keys in its #{CONFIG_NAME}: #{IMAGE_PROCESSOR_MAPPING_NAMES.keys.join(", ")}"
+    end
+  end
+end
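
Note that mapping values here are arrays, which is why `from_pretrained` destructures `image_processor_tuple` into a slow and a fast class (also note that in 0.1.0 a config that explicitly names an `image_processor_type` still hits an unimplemented `raise Todo` branch, so only the `IMAGE_PROCESSOR_MAPPING` fallback completes). A sketch of the lazy lookup itself, fully derived from the code above (direct use shown for exposition; assumes the gem is loaded):

```ruby
Transformers::IMAGE_PROCESSOR_MAPPING.include?("ViTConfig") # => true
Transformers::IMAGE_PROCESSOR_MAPPING["ViTConfig"]
# => [Transformers::Vit::ViTImageProcessor]  ("vit" => ["ViTImageProcessor"])
```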
data/lib/transformers/models/auto/modeling_auto.rb
@@ -0,0 +1,80 @@
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  MODEL_MAPPING_NAMES = {
+    "bert" => "BertModel",
+    "distilbert" => "DistilBertModel",
+    "vit" => "ViTModel"
+  }
+
+  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+    "bert" => "BertForMaskedLM"
+  }
+
+  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+    "distilbert" => "DistilBertForSequenceClassification"
+  }
+
+  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "distilbert" => "DistilBertForQuestionAnswering"
+  }
+
+  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+    "vit" => "ViTForImageClassification"
+  }
+
+  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+    "bert" => "BertForTokenClassification"
+  }
+
+  MODEL_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
+  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+  )
+  MODEL_FOR_MASKED_LM_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+  )
+  MODEL_FOR_QUESTION_ANSWERING_MAPPING = LazyAutoMapping.new(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+  )
+  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+  )
+
+  class AutoModel < BaseAutoModelClass
+    self._model_mapping = MODEL_MAPPING
+  end
+
+  class AutoModelForMaskedLM < BaseAutoModelClass
+    self._model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+  end
+
+  class AutoModelForSequenceClassification < BaseAutoModelClass
+    self._model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+  end
+
+  class AutoModelForQuestionAnswering < BaseAutoModelClass
+    self._model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
+  end
+
+  class AutoModelForTokenClassification < BaseAutoModelClass
+    self._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+  end
+
+  class AutoModelForImageClassification < BaseAutoModelClass
+    self._model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+  end
+end
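
A hedged example of the Auto classes defined here (assumes network access, torch-rb installed, and that the model ids are public Hub repos):

```ruby
model = Transformers::AutoModel.from_pretrained("bert-base-uncased")
# config.json's model_type "bert" => BertConfig => MODEL_MAPPING => Transformers::Bert::BertModel

# Task-specific heads go through the matching Auto class:
qa_model = Transformers::AutoModelForQuestionAnswering.from_pretrained(
  "distilbert-base-cased-distilled-squad"
)
```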
data/lib/transformers/models/auto/tokenization_auto.rb
@@ -0,0 +1,160 @@
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  TOKENIZER_MAPPING_NAMES = {
+    "bert" => ["BertTokenizer", "BertTokenizerFast"],
+    "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
+  }
+
+  TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
+
+  class AutoTokenizer
+    class << self
+      def from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        config = kwargs.delete(:config)
+        kwargs[:_from_auto] = true
+
+        use_fast = kwargs.delete(:use_fast) { true }
+        tokenizer_type = kwargs.delete(:tokenizer_type) { nil }
+        trust_remote_code = kwargs.delete(:trust_remote_code)
+
+        if !tokenizer_type.nil?
+          raise Todo
+        end
+
+        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+        if tokenizer_config.include?("_commit_hash")
+          kwargs[:_commit_hash] = tokenizer_config["_commit_hash"]
+        end
+        config_tokenizer_class = tokenizer_config["tokenizer_class"]
+        _tokenizer_auto_map = nil
+        if tokenizer_config["auto_map"]
+          raise Todo
+        end
+
+        # If that did not work, let's try to use the config.
+        if config_tokenizer_class.nil?
+          if !config.is_a?(PretrainedConfig)
+            config = AutoConfig.from_pretrained(
+              pretrained_model_name_or_path, trust_remote_code: trust_remote_code, **kwargs
+            )
+            config_tokenizer_class = config.tokenizer_class
+            # if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+            #   tokenizer_auto_map = config.auto_map["AutoTokenizer"]
+          end
+        end
+
+        if !config_tokenizer_class.nil?
+          tokenizer_class = nil
+          if use_fast && !config_tokenizer_class.end_with?("Fast")
+            tokenizer_class_candidate = "#{config_tokenizer_class}Fast"
+            tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+          end
+          if tokenizer_class.nil?
+            tokenizer_class_candidate = config_tokenizer_class
+            tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+          end
+          if tokenizer_class.nil?
+            raise ArgumentError, "Tokenizer class #{tokenizer_class_candidate} does not exist or is not currently imported."
+          end
+          return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        end
+
+        model_type = config_class_to_model_type(config.class.name.split("::").last)
+        if !model_type.nil?
+          tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config.class.name.split("::").last]
+          if tokenizer_class_fast && (use_fast || tokenizer_class_py.nil?)
+            return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+          else
+            if !tokenizer_class_py.nil?
+              return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+            else
+              raise ArgumentError, "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed in order to use this tokenizer."
+            end
+          end
+        end
+
+        raise ArgumentError, "Unrecognized configuration class #{config.class.name} to build an AutoTokenizer."
+      end
+
+      private
+
+      def tokenizer_class_from_name(class_name)
+        if class_name == "PreTrainedTokenizerFast"
+          return PreTrainedTokenizerFast
+        end
+
+        TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
+          if tokenizers.include?(class_name)
+            cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+            raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
+            return cls
+          end
+        end
+
+        raise Todo
+      end
+
+      def get_tokenizer_config(
+        pretrained_model_name_or_path,
+        cache_dir: nil,
+        force_download: false,
+        resume_download: false,
+        proxies: nil,
+        token: nil,
+        revision: nil,
+        local_files_only: false,
+        subfolder: "",
+        **kwargs
+      )
+        commit_hash = kwargs[:_commit_hash]
+        resolved_config_file = Utils::Hub.cached_file(
+          pretrained_model_name_or_path,
+          TOKENIZER_CONFIG_FILE,
+          cache_dir: cache_dir,
+          force_download: force_download,
+          resume_download: resume_download,
+          proxies: proxies,
+          token: token,
+          revision: revision,
+          local_files_only: local_files_only,
+          subfolder: subfolder,
+          _raise_exceptions_for_gated_repo: false,
+          _raise_exceptions_for_missing_entries: false,
+          _raise_exceptions_for_connection_errors: false,
+          _commit_hash: commit_hash
+        )
+        if resolved_config_file.nil?
+          Transformers.logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
+          return {}
+        end
+        commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+
+        result = JSON.load_file(resolved_config_file)
+        result["_commit_hash"] = commit_hash
+        result
+      end
+
+      def config_class_to_model_type(config)
+        CONFIG_MAPPING_NAMES.each do |key, cls|
+          if cls == config
+            return key
+          end
+        end
+        nil
+      end
+    end
+  end
+end
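
A hedged example of the fast/slow selection above (assumes network access, and that the repo's tokenizer_config.json names "BertTokenizer", as the public bert-base-uncased repo does at the time of writing):

```ruby
tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
# use_fast defaults to true, so tokenizer_class_from_name("BertTokenizerFast") is tried first
tokenizer.class # => Transformers::Bert::BertTokenizerFast

slow = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased", use_fast: false)
slow.class # => Transformers::Bert::BertTokenizer
```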
data/lib/transformers/models/bert/configuration_bert.rb
@@ -0,0 +1,65 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module Bert
+    class BertConfig < PretrainedConfig
+      self.model_type = "bert"
+
+      attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+        :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+        :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+        :position_embedding_type, :use_cache, :classifier_dropout
+
+      def initialize(
+        vocab_size: 30522,
+        hidden_size: 768,
+        num_hidden_layers: 12,
+        num_attention_heads: 12,
+        intermediate_size: 3072,
+        hidden_act: "gelu",
+        hidden_dropout_prob: 0.1,
+        attention_probs_dropout_prob: 0.1,
+        max_position_embeddings: 512,
+        type_vocab_size: 2,
+        initializer_range: 0.02,
+        layer_norm_eps: 1e-12,
+        pad_token_id: 0,
+        position_embedding_type: "absolute",
+        use_cache: true,
+        classifier_dropout: nil,
+        **kwargs
+      )
+        super(pad_token_id: pad_token_id, **kwargs)
+
+        @vocab_size = vocab_size
+        @hidden_size = hidden_size
+        @num_hidden_layers = num_hidden_layers
+        @num_attention_heads = num_attention_heads
+        @hidden_act = hidden_act
+        @intermediate_size = intermediate_size
+        @hidden_dropout_prob = hidden_dropout_prob
+        @attention_probs_dropout_prob = attention_probs_dropout_prob
+        @max_position_embeddings = max_position_embeddings
+        @type_vocab_size = type_vocab_size
+        @initializer_range = initializer_range
+        @layer_norm_eps = layer_norm_eps
+        @position_embedding_type = position_embedding_type
+        @use_cache = use_cache
+        @classifier_dropout = classifier_dropout
+      end
+    end
+  end
+end
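
Since `BertConfig` exposes keyword defaults (BERT-base values) and matching readers, it can also be constructed directly without the Hub; a small sketch:

```ruby
config = Transformers::Bert::BertConfig.new(hidden_size: 384, num_hidden_layers: 6)
config.hidden_size # => 384 (overridden)
config.vocab_size  # => 30522 (default)
```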