transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/models/auto/auto_factory.rb
@@ -0,0 +1,138 @@
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class BaseAutoModelClass
    extend ClassAttribute

    class_attribute :_model_mapping

    class << self
      def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        config = kwargs.delete(:config)
        trust_remote_code = kwargs.delete(:trust_remote_code)
        hub_kwargs_names = [
          :cache_dir,
          :force_download,
          :local_files_only,
          :proxies,
          :resume_download,
          :revision,
          :subfolder,
          :use_auth_token,
          :token
        ]
        hub_kwargs = hub_kwargs_names.select { |k| kwargs.key?(k) }.to_h { |name| [name, kwargs.delete(name)] }
        code_revision = kwargs.delete(:code_revision)
        commit_hash = kwargs.delete(:_commit_hash)

        if commit_hash.nil?
          if !config.is_a?(PretrainedConfig)
            # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
            resolved_config_file = Utils::Hub.cached_file(
              pretrained_model_name_or_path,
              CONFIG_NAME,
              _raise_exceptions_for_gated_repo: false,
              _raise_exceptions_for_missing_entries: false,
              _raise_exceptions_for_connection_errors: false,
              **hub_kwargs
            )
            commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
          else
            raise Todo
          end
        end

        if !config.is_a?(PretrainedConfig)
          config, kwargs =
            AutoConfig.from_pretrained(
              pretrained_model_name_or_path,
              return_unused_kwargs: true,
              trust_remote_code: trust_remote_code,
              code_revision: code_revision,
              _commit_hash: commit_hash,
              **hub_kwargs,
              **kwargs
            )
        end

        model_class = _get_model_class(config, _model_mapping)
        model_class.from_pretrained(
          pretrained_model_name_or_path, *model_args, config: config, **hub_kwargs, **kwargs
        )
      end

      private

      def _get_model_class(config, model_mapping)
        supported_models = model_mapping[config.class.name.split("::").last]
        if !supported_models.is_a?(Array)
          return supported_models
        end

        raise Todo
      end
    end
  end

  class LazyAutoMapping
    def initialize(config_mapping, model_mapping)
      @config_mapping = config_mapping
      @reverse_config_mapping = config_mapping.invert
      @model_mapping = model_mapping
      @modules = {}
    end

    def [](key)
      model_type = @reverse_config_mapping[key]
      if @model_mapping[model_type]
        model_name = @model_mapping[model_type]
        return _load_attr_from_module(model_type, model_name)
      end

      raise KeyError, key
    end

    def include?(key)
      self[key]
      true
    rescue KeyError
      false
    end

    private

    def _load_attr_from_module(model_type, attr)
      module_name = model_type_to_module_name(model_type)
      if !@modules.include?(module_name)
        @modules[module_name] = Transformers.const_get(module_name.capitalize)
      end
      getattribute_from_module(@modules[module_name], attr)
    end

    def getattribute_from_module(mod, attr)
      if attr.nil?
        nil
      elsif attr.is_a?(Array)
        attr.map { |a| mod.const_get(a) }
      else
        mod.const_get(attr)
      end
    end

    def model_type_to_module_name(key)
      key
    end
  end
end
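A note on the dispatch above: `from_pretrained` resolves a config first, then keys the lazy mapping by the config's unqualified class name ("BertConfig" and so on), which `LazyAutoMapping` inverts back to a model type before `const_get`-ing the model class. A hedged sketch of that chain, assuming `AutoConfig.from_pretrained` returns a bare config when `return_unused_kwargs` is not set (model id illustrative):

config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
# "BertConfig" -> @reverse_config_mapping -> "bert" -> "BertModel",
# loaded lazily via Transformers.const_get("Bert").const_get("BertModel")
model_class = Transformers::MODEL_MAPPING[config.class.name.split("::").last]
model = model_class.from_pretrained("bert-base-uncased", config: config)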
data/lib/transformers/models/auto/configuration_auto.rb
@@ -0,0 +1,61 @@
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  CONFIG_MAPPING_NAMES = {
    "bert" => "BertConfig",
    "distilbert" => "DistilBertConfig",
    "vit" => "ViTConfig"
  }

  class LazyConfigMapping
    def initialize(mapping)
      @mapping = mapping
      @extra_content = {}
      @modules = {}
    end

    def [](key)
      value = @mapping.fetch(key)
      module_name = model_type_to_module_name(key)
      if !@modules.include?(module_name)
        @modules[module_name] = Transformers.const_get(module_name.capitalize)
      end
      @modules[module_name].const_get(value)
    end

    def model_type_to_module_name(key)
      key
    end
  end

  CONFIG_MAPPING = LazyConfigMapping.new(CONFIG_MAPPING_NAMES)

  class AutoConfig
    def self.from_pretrained(pretrained_model_name_or_path, **kwargs)
      kwargs[:_from_auto] = true
      kwargs[:name_or_path] = pretrained_model_name_or_path
      _trust_remote_code = kwargs.delete(:trust_remote_code)
      _code_revision = kwargs.delete(:code_revision)

      config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
      if config_dict[:model_type]
        config_class = CONFIG_MAPPING[config_dict[:model_type]]
        return config_class.from_dict(config_dict, **unused_kwargs)
      else
        raise Todo
      end
    end
  end
end
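The mapping above is the whole dispatch table: `AutoConfig` reads `model_type` from the downloaded config.json and indexes `CONFIG_MAPPING` with it, so only "bert", "distilbert", and "vit" resolve in 0.1.0. A hedged usage sketch (model id illustrative):

config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
# config.json's model_type ("bert") selects CONFIG_MAPPING["bert"], which lazily
# loads Transformers::Bert and returns BertConfig.from_dict(config_dict)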
data/lib/transformers/models/auto/feature_extraction_auto.rb
@@ -0,0 +1,20 @@
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  FEATURE_EXTRACTOR_MAPPING_NAMES = {}

  FEATURE_EXTRACTOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
end
data/lib/transformers/models/auto/image_processing_auto.rb
@@ -0,0 +1,104 @@
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  IMAGE_PROCESSOR_MAPPING_NAMES = {
    "vit" => ["ViTImageProcessor"]
  }

  IMAGE_PROCESSOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)

  class AutoImageProcessor
    def self.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
      config = kwargs.delete(:config)
      use_fast = kwargs.delete(:use_fast)
      trust_remote_code = kwargs.delete(:trust_remote_code)
      kwargs[:_from_auto] = true

      config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
      image_processor_class = config_dict[:image_processor_type]
      image_processor_auto_map = nil
      if (config_dict[:auto_map] || {}).include?("AutoImageProcessor")
        image_processor_auto_map = config_dict[:auto_map]["AutoImageProcessor"]
      end

      # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
      # and if so, infer the image processor class from there.
      if image_processor_class.nil? && image_processor_auto_map.nil?
        feature_extractor_class = config_dict.delete(:feature_extractor_type)
        if !feature_extractor_class.nil?
          image_processor_class = feature_extractor_class.sub("FeatureExtractor", "ImageProcessor")
        end
        if (config_dict[:auto_map] || {}).include?("AutoFeatureExtractor")
          feature_extractor_auto_map = config_dict[:auto_map]["AutoFeatureExtractor"]
          image_processor_auto_map = feature_extractor_auto_map.sub("FeatureExtractor", "ImageProcessor")
        end
      end

      # If we don't find the image processor class in the image processor config, let's try the model config.
      if image_processor_class.nil? && image_processor_auto_map.nil?
        if !config.is_a?(PretrainedConfig)
          config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        end
        # It could be in `config.image_processor_type`
        image_processor_class = config.instance_variable_get(:@image_processor_type)
      end

      if !image_processor_class.nil?
        raise Todo
      end

      has_remote_code = !image_processor_auto_map.nil?
      has_local_code = !image_processor_class.nil? || IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
      trust_remote_code = DynamicModuleUtils.resolve_trust_remote_code(
        trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
      )

      if !image_processor_auto_map.nil? && !image_processor_auto_map.is_a?(Array)
        raise Todo
      end

      if has_remote_code && trust_remote_code
        raise Todo
      elsif !image_processor_class.nil?
        return image_processor_class.from_dict(config_dict, **kwargs)
      # Last try: we use the IMAGE_PROCESSOR_MAPPING.
      elsif IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
        image_processor_tuple = IMAGE_PROCESSOR_MAPPING[config.class.name.split("::").last]

        image_processor_class_py, image_processor_class_fast = image_processor_tuple

        if !use_fast && !image_processor_class_fast.nil?
          _warning_fast_image_processor_available(image_processor_class_fast)
        end

        if image_processor_class_fast && (use_fast || image_processor_class_py.nil?)
          return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        else
          if !image_processor_class_py.nil?
            return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
          else
            raise ArgumentError,
              "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
          end
        end
      end

      raise ArgumentError,
        "Unrecognized image processor in #{pretrained_model_name_or_path}. Should have a " +
        "`image_processor_type` key in its #{IMAGE_PROCESSOR_NAME} of #{CONFIG_NAME}, or one of the following " +
        "`model_type` keys in its #{CONFIG_NAME}: #{IMAGE_PROCESSOR_MAPPING_NAMES.keys.join(", ")}"
    end
  end
end
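Only ViT registers an image processor in this release, and its mapping entry `["ViTImageProcessor"]` has no fast variant, so the slow class is returned whatever `use_fast` is. A hedged usage sketch (model id illustrative):

processor = Transformers::AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Resolves via IMAGE_PROCESSOR_MAPPING keyed by "ViTConfig"; image_processor_class_fast
# is nil here, so ViTImageProcessor.from_pretrained handles the load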
data/lib/transformers/models/auto/modeling_auto.rb
@@ -0,0 +1,80 @@
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  MODEL_MAPPING_NAMES = {
    "bert" => "BertModel",
    "distilbert" => "DistilBertModel",
    "vit" => "ViTModel"
  }

  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
    "bert" => "BertForMaskedLM"
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
    "distilbert" => "DistilBertForSequenceClassification"
  }

  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
    "distilbert" => "DistilBertForQuestionAnswering"
  }

  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
    "vit" => "ViTForImageClassification"
  }

  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
    "bert" => "BertForTokenClassification"
  }

  MODEL_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
  )
  MODEL_FOR_MASKED_LM_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
  )
  MODEL_FOR_QUESTION_ANSWERING_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
  )
  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  )

  class AutoModel < BaseAutoModelClass
    self._model_mapping = MODEL_MAPPING
  end

  class AutoModelForMaskedLM < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_MASKED_LM_MAPPING
  end

  class AutoModelForSequenceClassification < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
  end

  class AutoModelForQuestionAnswering < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
  end

  class AutoModelForTokenClassification < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
  end

  class AutoModelForImageClassification < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
  end
end
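Each `AutoModelFor*` class above is just `BaseAutoModelClass` with a different `_model_mapping`, so the subclass picked determines which task head loads. A hedged usage sketch (model ids illustrative):

model = Transformers::AutoModel.from_pretrained("bert-base-uncased")
# => the bare BertModel, per MODEL_MAPPING_NAMES["bert"]
qa = Transformers::AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
# => DistilBertForQuestionAnswering, per MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES["distilbert"]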
data/lib/transformers/models/auto/tokenization_auto.rb
@@ -0,0 +1,160 @@
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  TOKENIZER_MAPPING_NAMES = {
    "bert" => ["BertTokenizer", "BertTokenizerFast"],
    "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
  }

  TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

  class AutoTokenizer
    class << self
      def from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        config = kwargs.delete(:config)
        kwargs[:_from_auto] = true

        use_fast = kwargs.delete(:use_fast) { true }
        tokenizer_type = kwargs.delete(:tokenizer_type) { nil }
        trust_remote_code = kwargs.delete(:trust_remote_code)

        if !tokenizer_type.nil?
          raise Todo
        end

        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        if tokenizer_config.include?("_commit_hash")
          kwargs[:_commit_hash] = tokenizer_config["_commit_hash"]
        end
        config_tokenizer_class = tokenizer_config["tokenizer_class"]
        _tokenizer_auto_map = nil
        if tokenizer_config["auto_map"]
          raise Todo
        end

        # If that did not work, let's try to use the config.
        if config_tokenizer_class.nil?
          if !config.is_a?(PretrainedConfig)
            config = AutoConfig.from_pretrained(
              pretrained_model_name_or_path, trust_remote_code: trust_remote_code, **kwargs
            )
            config_tokenizer_class = config.tokenizer_class
            # if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
            #   tokenizer_auto_map = config.auto_map["AutoTokenizer"]
          end
        end

        if !config_tokenizer_class.nil?
          tokenizer_class = nil
          if use_fast && !config_tokenizer_class.end_with?("Fast")
            tokenizer_class_candidate = "#{config_tokenizer_class}Fast"
            tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
          end
          if tokenizer_class.nil?
            tokenizer_class_candidate = config_tokenizer_class
            tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
          end
          if tokenizer_class.nil?
            raise ArgumentError, "Tokenizer class #{tokenizer_class_candidate} does not exist or is not currently imported."
          end
          return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        end

        model_type = config_class_to_model_type(config.class.name.split("::").last)
        if !model_type.nil?
          tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config.class.name.split("::").last]
          if tokenizer_class_fast && (use_fast || tokenizer_class_py.nil?)
            return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
          else
            if !tokenizer_class_py.nil?
              return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else
              raise ArgumentError, "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed in order to use this tokenizer."
            end
          end
        end

        raise ArgumentError, "Unrecognized configuration class #{config.class.name} to build an AutoTokenizer."
      end

      private

      def tokenizer_class_from_name(class_name)
        if class_name == "PreTrainedTokenizerFast"
          return PreTrainedTokenizerFast
        end

        TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
          if tokenizers.include?(class_name)
            cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
            raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
            return cls
          end
        end

        raise Todo
      end

      def get_tokenizer_config(
        pretrained_model_name_or_path,
        cache_dir: nil,
        force_download: false,
        resume_download: false,
        proxies: nil,
        token: nil,
        revision: nil,
        local_files_only: false,
        subfolder: "",
        **kwargs
      )
        commit_hash = kwargs[:_commit_hash]
        resolved_config_file = Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          TOKENIZER_CONFIG_FILE,
          cache_dir: cache_dir,
          force_download: force_download,
          resume_download: resume_download,
          proxies: proxies,
          token: token,
          revision: revision,
          local_files_only: local_files_only,
          subfolder: subfolder,
          _raise_exceptions_for_gated_repo: false,
          _raise_exceptions_for_missing_entries: false,
          _raise_exceptions_for_connection_errors: false,
          _commit_hash: commit_hash
        )
        if resolved_config_file.nil?
          Transformers.logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
          return {}
        end
        commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)

        result = JSON.load_file(resolved_config_file)
        result["_commit_hash"] = commit_hash
        result
      end

      def config_class_to_model_type(config)
        CONFIG_MAPPING_NAMES.each do |key, cls|
          if cls == config
            return key
          end
        end
        nil
      end
    end
  end
end
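`use_fast` defaults to true above, so when tokenizer_config.json names a slow class such as `BertTokenizer`, the factory first tries the `...Fast` sibling. A hedged usage sketch (model id illustrative):

fast = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
# tokenizer_class_from_name("BertTokenizerFast") succeeds, so the fast class wins
slow = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased", use_fast: false)
# skips the Fast candidate and loads BertTokenizer instead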
data/lib/transformers/models/bert/configuration_bert.rb
@@ -0,0 +1,65 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Bert
    class BertConfig < PretrainedConfig
      self.model_type = "bert"

      attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
        :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
        :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
        :position_embedding_type, :use_cache, :classifier_dropout

      def initialize(
        vocab_size: 30522,
        hidden_size: 768,
        num_hidden_layers: 12,
        num_attention_heads: 12,
        intermediate_size: 3072,
        hidden_act: "gelu",
        hidden_dropout_prob: 0.1,
        attention_probs_dropout_prob: 0.1,
        max_position_embeddings: 512,
        type_vocab_size: 2,
        initializer_range: 0.02,
        layer_norm_eps: 1e-12,
        pad_token_id: 0,
        position_embedding_type: "absolute",
        use_cache: true,
        classifier_dropout: nil,
        **kwargs
      )
        super(pad_token_id: pad_token_id, **kwargs)

        @vocab_size = vocab_size
        @hidden_size = hidden_size
        @num_hidden_layers = num_hidden_layers
        @num_attention_heads = num_attention_heads
        @hidden_act = hidden_act
        @intermediate_size = intermediate_size
        @hidden_dropout_prob = hidden_dropout_prob
        @attention_probs_dropout_prob = attention_probs_dropout_prob
        @max_position_embeddings = max_position_embeddings
        @type_vocab_size = type_vocab_size
        @initializer_range = initializer_range
        @layer_norm_eps = layer_norm_eps
        @position_embedding_type = position_embedding_type
        @use_cache = use_cache
        @classifier_dropout = classifier_dropout
      end
    end
  end
end
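The defaults above reproduce BERT-base; keywords override them and anything unrecognized passes through to `PretrainedConfig` via `**kwargs`. A minimal construction sketch:

config = Transformers::Bert::BertConfig.new(num_hidden_layers: 6, hidden_dropout_prob: 0.2)
config.num_hidden_layers # => 6
config.vocab_size        # => 30522 (default)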