transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0

data/lib/transformers/models/auto/auto_factory.rb (new file, +138 lines)

# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class BaseAutoModelClass
    extend ClassAttribute

    class_attribute :_model_mapping

    class << self
      def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        config = kwargs.delete(:config)
        trust_remote_code = kwargs.delete(:trust_remote_code)
        hub_kwargs_names = [
          :cache_dir,
          :force_download,
          :local_files_only,
          :proxies,
          :resume_download,
          :revision,
          :subfolder,
          :use_auth_token,
          :token
        ]
        hub_kwargs = hub_kwargs_names.select { |k| kwargs.key?(k) }.to_h { |name| [name, kwargs.delete(name)] }
        code_revision = kwargs.delete(:code_revision)
        commit_hash = kwargs.delete(:_commit_hash)

        if commit_hash.nil?
          if !config.is_a?(PretrainedConfig)
            # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
            resolved_config_file = Utils::Hub.cached_file(
              pretrained_model_name_or_path,
              CONFIG_NAME,
              _raise_exceptions_for_gated_repo: false,
              _raise_exceptions_for_missing_entries: false,
              _raise_exceptions_for_connection_errors: false,
              **hub_kwargs
            )
            commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
          else
            raise Todo
          end
        end

        if !config.is_a?(PretrainedConfig)
          config, kwargs =
            AutoConfig.from_pretrained(
              pretrained_model_name_or_path,
              return_unused_kwargs: true,
              trust_remote_code: trust_remote_code,
              code_revision: code_revision,
              _commit_hash: commit_hash,
              **hub_kwargs,
              **kwargs
            )
        end

        model_class = _get_model_class(config, _model_mapping)
        model_class.from_pretrained(
          pretrained_model_name_or_path, *model_args, config: config, **hub_kwargs, **kwargs
        )
      end

      private

      def _get_model_class(config, model_mapping)
        supported_models = model_mapping[config.class.name.split("::").last]
        if !supported_models.is_a?(Array)
          return supported_models
        end

        raise Todo
      end
    end
  end

  class LazyAutoMapping
    def initialize(config_mapping, model_mapping)
      @config_mapping = config_mapping
      @reverse_config_mapping = config_mapping.invert
      @model_mapping = model_mapping
      @modules = {}
    end

    def [](key)
      model_type = @reverse_config_mapping[key]
      if @model_mapping[model_type]
        model_name = @model_mapping[model_type]
        return _load_attr_from_module(model_type, model_name)
      end

      raise KeyError, key
    end

    def include?(key)
      self[key]
      true
    rescue KeyError
      false
    end

    private

    def _load_attr_from_module(model_type, attr)
      module_name = model_type_to_module_name(model_type)
      if !@modules.include?(module_name)
        @modules[module_name] = Transformers.const_get(module_name.capitalize)
      end
      getattribute_from_module(@modules[module_name], attr)
    end

    def getattribute_from_module(mod, attr)
      if attr.nil?
        nil
      elsif attr.is_a?(Array)
        attr.map { |a| mod.const_get(a) }
      else
        mod.const_get(attr)
      end
    end

    def model_type_to_module_name(key)
      key
    end
  end
end
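
A minimal sketch (not part of the gem) of how these two classes cooperate: `BaseAutoModelClass.from_pretrained` looks up the loaded config's demodulized class name in a `LazyAutoMapping`, which only loads the model module on first access. The mapping contents below mirror the `"bert"` entries from `configuration_auto.rb` and `modeling_auto.rb` further down, and the resolved constant assumes `Transformers::Bert::BertModel` from `modeling_bert.rb`:

    mapping = Transformers::LazyAutoMapping.new(
      {"bert" => "BertConfig"}, # model_type => config class name
      {"bert" => "BertModel"}   # model_type => model class name
    )
    # _get_model_class keys on the config's demodulized class name:
    mapping["BertConfig"]          # => Transformers::Bert::BertModel, resolved lazily
    mapping.include?("GPT2Config") # => false (the KeyError is rescued)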

data/lib/transformers/models/auto/configuration_auto.rb (new file, +61 lines)

# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  CONFIG_MAPPING_NAMES = {
    "bert" => "BertConfig",
    "distilbert" => "DistilBertConfig",
    "vit" => "ViTConfig"
  }

  class LazyConfigMapping
    def initialize(mapping)
      @mapping = mapping
      @extra_content = {}
      @modules = {}
    end

    def [](key)
      value = @mapping.fetch(key)
      module_name = model_type_to_module_name(key)
      if !@modules.include?(module_name)
        @modules[module_name] = Transformers.const_get(module_name.capitalize)
      end
      @modules[module_name].const_get(value)
    end

    def model_type_to_module_name(key)
      key
    end
  end

  CONFIG_MAPPING = LazyConfigMapping.new(CONFIG_MAPPING_NAMES)

  class AutoConfig
    def self.from_pretrained(pretrained_model_name_or_path, **kwargs)
      kwargs[:_from_auto] = true
      kwargs[:name_or_path] = pretrained_model_name_or_path
      _trust_remote_code = kwargs.delete(:trust_remote_code)
      _code_revision = kwargs.delete(:code_revision)

      config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
      if config_dict[:model_type]
        config_class = CONFIG_MAPPING[config_dict[:model_type]]
        return config_class.from_dict(config_dict, **unused_kwargs)
      else
        raise Todo
      end
    end
  end
end
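
A hedged usage sketch (the model id is illustrative and requires a download or a warm cache): `AutoConfig.from_pretrained` fetches `config.json`, reads its `model_type` key, picks the config class through `CONFIG_MAPPING`, and builds it with `from_dict`:

    config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
    config.class        # => Transformers::Bert::BertConfig ("model_type": "bert")
    config.hidden_size  # => 768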

data/lib/transformers/models/auto/feature_extraction_auto.rb (new file, +20 lines)

# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  FEATURE_EXTRACTOR_MAPPING_NAMES = {
  }

  FEATURE_EXTRACTOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
end

data/lib/transformers/models/auto/image_processing_auto.rb (new file, +104 lines)

# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  IMAGE_PROCESSOR_MAPPING_NAMES = {
    "vit" => ["ViTImageProcessor"]
  }

  IMAGE_PROCESSOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)

  class AutoImageProcessor
    def self.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
      config = kwargs.delete(:config)
      use_fast = kwargs.delete(:use_fast)
      trust_remote_code = kwargs.delete(:trust_remote_code)
      kwargs[:_from_auto] = true

      config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
      image_processor_class = config_dict[:image_processor_type]
      image_processor_auto_map = nil
      if (config_dict[:auto_map] || {}).include?("AutoImageProcessor")
        image_processor_auto_map = config_dict[:auto_map]["AutoImageProcessor"]
      end

      # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
      # and if so, infer the image processor class from there.
      if image_processor_class.nil? && image_processor_auto_map.nil?
        feature_extractor_class = config_dict.delete(:feature_extractor_type)
        if !feature_extractor_class.nil?
          image_processor_class = feature_extractor_class.sub("FeatureExtractor", "ImageProcessor")
        end
        if (config_dict[:auto_map] || {}).include?("AutoFeatureExtractor")
          feature_extractor_auto_map = config_dict[:auto_map]["AutoFeatureExtractor"]
          image_processor_auto_map = feature_extractor_auto_map.sub("FeatureExtractor", "ImageProcessor")
        end
      end

      # If we don't find the image processor class in the image processor config, let's try the model config.
      if image_processor_class.nil? && image_processor_auto_map.nil?
        if !config.is_a?(PretrainedConfig)
          config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        end
        # It could be in `config.image_processor_type`
        image_processor_class = config.instance_variable_get(:@image_processor_type)
      end

      if !image_processor_class.nil?
        raise Todo
      end

      has_remote_code = !image_processor_auto_map.nil?
      has_local_code = !image_processor_class.nil? || IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
      trust_remote_code = DynamicModuleUtils.resolve_trust_remote_code(
        trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
      )

      if !image_processor_auto_map.nil? && !image_processor_auto_map.is_a?(Array)
        raise Todo
      end

      if has_remote_code && trust_remote_code
        raise Todo
      elsif !image_processor_class.nil?
        return image_processor_class.from_dict(config_dict, **kwargs)
      # Last try: we use the IMAGE_PROCESSOR_MAPPING.
      elsif IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
        image_processor_tuple = IMAGE_PROCESSOR_MAPPING[config.class.name.split("::").last]

        image_processor_class_py, image_processor_class_fast = image_processor_tuple

        if !use_fast && !image_processor_class_fast.nil?
          _warning_fast_image_processor_available(image_processor_class_fast)
        end

        if image_processor_class_fast && (use_fast || image_processor_class_py.nil?)
          return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        else
          if !image_processor_class_py.nil?
            return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
          else
            raise ArgumentError,
              "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
          end
        end
      end

      raise ArgumentError,
        "Unrecognized image processor in #{pretrained_model_name_or_path}. Should have a " +
        "`image_processor_type` key in its #{IMAGE_PROCESSOR_NAME} of #{CONFIG_NAME}, or one of the following " +
        "`model_type` keys in its #{CONFIG_NAME}: #{IMAGE_PROCESSOR_MAPPING_NAMES.keys.join(", ")}"
    end
  end
end
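
A small resolution sketch, verifiable against `LazyAutoMapping` above: the mapping keys on the config class name, and because the `"vit"` entry is an array, the lookup returns an array of constants, which `from_pretrained` destructures into slow and fast variants (here the fast slot is nil, so `image_processor_class_py` wins):

    Transformers::IMAGE_PROCESSOR_MAPPING.include?("ViTConfig") # => true
    Transformers::IMAGE_PROCESSOR_MAPPING["ViTConfig"]
    # => [Transformers::Vit::ViTImageProcessor], loaded on first access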

data/lib/transformers/models/auto/modeling_auto.rb (new file, +80 lines)

# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  MODEL_MAPPING_NAMES = {
    "bert" => "BertModel",
    "distilbert" => "DistilBertModel",
    "vit" => "ViTModel"
  }

  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
    "bert" => "BertForMaskedLM"
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
    "distilbert" => "DistilBertForSequenceClassification"
  }

  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
    "distilbert" => "DistilBertForQuestionAnswering"
  }

  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
    "vit" => "ViTForImageClassification"
  }

  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
    "bert" => "BertForTokenClassification"
  }

  MODEL_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
  )
  MODEL_FOR_MASKED_LM_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
  )
  MODEL_FOR_QUESTION_ANSWERING_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
  )
  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  )

  class AutoModel < BaseAutoModelClass
    self._model_mapping = MODEL_MAPPING
  end

  class AutoModelForMaskedLM < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_MASKED_LM_MAPPING
  end

  class AutoModelForSequenceClassification < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
  end

  class AutoModelForQuestionAnswering < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
  end

  class AutoModelForTokenClassification < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
  end

  class AutoModelForImageClassification < BaseAutoModelClass
    self._model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
  end
end
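
The `AutoModelFor*` subclasses differ only in the `_model_mapping` they assign; everything else comes from `BaseAutoModelClass`. A hedged usage sketch (the model id is illustrative and requires a download or a warm cache):

    model = Transformers::AutoModelForSequenceClassification.from_pretrained(
      "distilbert-base-uncased-finetuned-sst-2-english"
    )
    model.class # => Transformers::Distilbert::DistilBertForSequenceClassification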

data/lib/transformers/models/auto/tokenization_auto.rb (new file, +160 lines)

# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  TOKENIZER_MAPPING_NAMES = {
    "bert" => ["BertTokenizer", "BertTokenizerFast"],
    "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
  }

  TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

  class AutoTokenizer
    class << self
      def from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        config = kwargs.delete(:config)
        kwargs[:_from_auto] = true

        use_fast = kwargs.delete(:use_fast) { true }
        tokenizer_type = kwargs.delete(:tokenizer_type) { nil }
        trust_remote_code = kwargs.delete(:trust_remote_code)

        if !tokenizer_type.nil?
          raise Todo
        end

        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        if tokenizer_config.include?("_commit_hash")
          kwargs[:_commit_hash] = tokenizer_config["_commit_hash"]
        end
        config_tokenizer_class = tokenizer_config["tokenizer_class"]
        _tokenizer_auto_map = nil
        if tokenizer_config["auto_map"]
          raise Todo
        end

        # If that did not work, let's try to use the config.
        if config_tokenizer_class.nil?
          if !config.is_a?(PretrainedConfig)
            config = AutoConfig.from_pretrained(
              pretrained_model_name_or_path, trust_remote_code: trust_remote_code, **kwargs
            )
            config_tokenizer_class = config.tokenizer_class
            # if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
            # tokenizer_auto_map = config.auto_map["AutoTokenizer"]
          end
        end

        if !config_tokenizer_class.nil?
          tokenizer_class = nil
          if use_fast && !config_tokenizer_class.end_with?("Fast")
            tokenizer_class_candidate = "#{config_tokenizer_class}Fast"
            tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
          end
          if tokenizer_class.nil?
            tokenizer_class_candidate = config_tokenizer_class
            tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
          end
          if tokenizer_class.nil?
            raise ArgumentError, "Tokenizer class #{tokenizer_class_candidate} does not exist or is not currently imported."
          end
          return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        end

        model_type = config_class_to_model_type(config.class.name.split("::").last)
        if !model_type.nil?
          tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config.class.name.split("::").last]
          if tokenizer_class_fast && (use_fast || tokenizer_class_py.nil?)
            return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
          else
            if !tokenizer_class_py.nil?
              return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else
              raise ArgumentError, "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed in order to use this tokenizer."
            end
          end
        end

        raise ArgumentError, "Unrecognized configuration class #{config.class.name} to build an AutoTokenizer."
      end

      private

      def tokenizer_class_from_name(class_name)
        if class_name == "PreTrainedTokenizerFast"
          return PreTrainedTokenizerFast
        end

        TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
          if tokenizers.include?(class_name)
            cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
            raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
            return cls
          end
        end

        raise Todo
      end

      def get_tokenizer_config(
        pretrained_model_name_or_path,
        cache_dir: nil,
        force_download: false,
        resume_download: false,
        proxies: nil,
        token: nil,
        revision: nil,
        local_files_only: false,
        subfolder: "",
        **kwargs
      )
        commit_hash = kwargs[:_commit_hash]
        resolved_config_file = Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          TOKENIZER_CONFIG_FILE,
          cache_dir: cache_dir,
          force_download: force_download,
          resume_download: resume_download,
          proxies: proxies,
          token: token,
          revision: revision,
          local_files_only: local_files_only,
          subfolder: subfolder,
          _raise_exceptions_for_gated_repo: false,
          _raise_exceptions_for_missing_entries: false,
          _raise_exceptions_for_connection_errors: false,
          _commit_hash: commit_hash
        )
        if resolved_config_file.nil?
          Transformers.logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
          return {}
        end
        commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)

        result = JSON.load_file(resolved_config_file)
        result["_commit_hash"] = commit_hash
        result
      end

      def config_class_to_model_type(config)
        CONFIG_MAPPING_NAMES.each do |key, cls|
          if cls == config
            return key
          end
        end
        nil
      end
    end
  end
end
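
A hedged usage sketch (the model id is illustrative, and the resolved class depends on the checkpoint's `tokenizer_config.json`): with the default `use_fast: true`, a config naming `BertTokenizer` is first tried as the `BertTokenizerFast` candidate, which `tokenizer_class_from_name` finds under the `"bert"` entry of `TOKENIZER_MAPPING_NAMES`:

    tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.class # => Transformers::Bert::BertTokenizerFast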

data/lib/transformers/models/bert/configuration_bert.rb (new file, +65 lines)

# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Bert
    class BertConfig < PretrainedConfig
      self.model_type = "bert"

      attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
        :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
        :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
        :position_embedding_type, :use_cache, :classifier_dropout

      def initialize(
        vocab_size: 30522,
        hidden_size: 768,
        num_hidden_layers: 12,
        num_attention_heads: 12,
        intermediate_size: 3072,
        hidden_act: "gelu",
        hidden_dropout_prob: 0.1,
        attention_probs_dropout_prob: 0.1,
        max_position_embeddings: 512,
        type_vocab_size: 2,
        initializer_range: 0.02,
        layer_norm_eps: 1e-12,
        pad_token_id: 0,
        position_embedding_type: "absolute",
        use_cache: true,
        classifier_dropout: nil,
        **kwargs
      )
        super(pad_token_id: pad_token_id, **kwargs)

        @vocab_size = vocab_size
        @hidden_size = hidden_size
        @num_hidden_layers = num_hidden_layers
        @num_attention_heads = num_attention_heads
        @hidden_act = hidden_act
        @intermediate_size = intermediate_size
        @hidden_dropout_prob = hidden_dropout_prob
        @attention_probs_dropout_prob = attention_probs_dropout_prob
        @max_position_embeddings = max_position_embeddings
        @type_vocab_size = type_vocab_size
        @initializer_range = initializer_range
        @layer_norm_eps = layer_norm_eps
        @position_embedding_type = position_embedding_type
        @use_cache = use_cache
        @classifier_dropout = classifier_dropout
      end
    end
  end
end
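
Since every argument has a default, the config can also be built directly; a small sketch:

    config = Transformers::Bert::BertConfig.new(num_hidden_layers: 6, hidden_size: 384)
    config.hidden_size # => 384
    config.vocab_size  # => 30522 (default)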