transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/models/auto/auto_factory.rb
@@ -0,0 +1,138 @@
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class BaseAutoModelClass
+     extend ClassAttribute
+
+     class_attribute :_model_mapping
+
+     class << self
+       def from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+         config = kwargs.delete(:config)
+         trust_remote_code = kwargs.delete(:trust_remote_code)
+         hub_kwargs_names = [
+           :cache_dir,
+           :force_download,
+           :local_files_only,
+           :proxies,
+           :resume_download,
+           :revision,
+           :subfolder,
+           :use_auth_token,
+           :token
+         ]
+         hub_kwargs = hub_kwargs_names.select { |k| kwargs.key?(k) }.to_h { |name| [name, kwargs.delete(name)] }
+         code_revision = kwargs.delete(:code_revision)
+         commit_hash = kwargs.delete(:_commit_hash)
+
+         if commit_hash.nil?
+           if !config.is_a?(PretrainedConfig)
+             # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+             resolved_config_file = Utils::Hub.cached_file(
+               pretrained_model_name_or_path,
+               CONFIG_NAME,
+               _raise_exceptions_for_gated_repo: false,
+               _raise_exceptions_for_missing_entries: false,
+               _raise_exceptions_for_connection_errors: false,
+               **hub_kwargs
+             )
+             commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+           else
+             raise Todo
+           end
+         end
+
+         if !config.is_a?(PretrainedConfig)
+           config, kwargs =
+             AutoConfig.from_pretrained(
+               pretrained_model_name_or_path,
+               return_unused_kwargs: true,
+               trust_remote_code: trust_remote_code,
+               code_revision: code_revision,
+               _commit_hash: commit_hash,
+               **hub_kwargs,
+               **kwargs
+             )
+         end
+
+         model_class = _get_model_class(config, _model_mapping)
+         model_class.from_pretrained(
+           pretrained_model_name_or_path, *model_args, config: config, **hub_kwargs, **kwargs
+         )
+       end
+
+       private
+
+       def _get_model_class(config, model_mapping)
+         supported_models = model_mapping[config.class.name.split("::").last]
+         if !supported_models.is_a?(Array)
+           return supported_models
+         end
+
+         raise Todo
+       end
+     end
+   end
+
+   class LazyAutoMapping
+     def initialize(config_mapping, model_mapping)
+       @config_mapping = config_mapping
+       @reverse_config_mapping = config_mapping.invert
+       @model_mapping = model_mapping
+       @modules = {}
+     end
+
+     def [](key)
+       model_type = @reverse_config_mapping[key]
+       if @model_mapping[model_type]
+         model_name = @model_mapping[model_type]
+         return _load_attr_from_module(model_type, model_name)
+       end
+
+       raise KeyError, key
+     end
+
+     def include?(key)
+       self[key]
+       true
+     rescue KeyError
+       false
+     end
+
+     private
+
+     def _load_attr_from_module(model_type, attr)
+       module_name = model_type_to_module_name(model_type)
+       if !@modules.include?(module_name)
+         @modules[module_name] = Transformers.const_get(module_name.capitalize)
+       end
+       getattribute_from_module(@modules[module_name], attr)
+     end
+
+     def getattribute_from_module(mod, attr)
+       if attr.nil?
+         nil
+       elsif attr.is_a?(Array)
+         attr.map { |a| mod.const_get(a) }
+       else
+         mod.const_get(attr)
+       end
+     end
+
+     def model_type_to_module_name(key)
+       key
+     end
+   end
+ end
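For orientation, here is a minimal sketch (not part of the diff) of how LazyAutoMapping resolves classes. The toy mappings below mirror the shape of CONFIG_MAPPING_NAMES and MODEL_MAPPING_NAMES from the files that follow, and assume the Transformers::Bert module has been loaded:

    mapping = Transformers::LazyAutoMapping.new(
      { "bert" => "BertConfig" }, # model type => config class name
      { "bert" => "BertModel" }   # model type => model class name
    )
    # [] inverts the config mapping ("BertConfig" => "bert"), then lazily
    # resolves the constant via Transformers.const_get("bert".capitalize)
    mapping["BertConfig"]         # => Transformers::Bert::BertModel
    mapping.include?("ViTConfig") # => false (the KeyError is rescued)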
data/lib/transformers/models/auto/configuration_auto.rb
@@ -0,0 +1,61 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   CONFIG_MAPPING_NAMES = {
+     "bert" => "BertConfig",
+     "distilbert" => "DistilBertConfig",
+     "vit" => "ViTConfig"
+   }
+
+   class LazyConfigMapping
+     def initialize(mapping)
+       @mapping = mapping
+       @extra_content = {}
+       @modules = {}
+     end
+
+     def [](key)
+       value = @mapping.fetch(key)
+       module_name = model_type_to_module_name(key)
+       if !@modules.include?(module_name)
+         @modules[module_name] = Transformers.const_get(module_name.capitalize)
+       end
+       @modules[module_name].const_get(value)
+     end
+
+     def model_type_to_module_name(key)
+       key
+     end
+   end
+
+   CONFIG_MAPPING = LazyConfigMapping.new(CONFIG_MAPPING_NAMES)
+
+   class AutoConfig
+     def self.from_pretrained(pretrained_model_name_or_path, **kwargs)
+       kwargs[:_from_auto] = true
+       kwargs[:name_or_path] = pretrained_model_name_or_path
+       _trust_remote_code = kwargs.delete(:trust_remote_code)
+       _code_revision = kwargs.delete(:code_revision)
+
+       config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+       if config_dict[:model_type]
+         config_class = CONFIG_MAPPING[config_dict[:model_type]]
+         return config_class.from_dict(config_dict, **unused_kwargs)
+       else
+         raise Todo
+       end
+     end
+   end
+ end
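Illustrative usage of AutoConfig (not part of the diff; the model id is an example and the first call fetches config.json from the Hugging Face Hub or a local cache):

    config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
    # config.json carries "model_type": "bert", so CONFIG_MAPPING["bert"]
    # lazily resolves to the BertConfig constant
    config.class # => Transformers::Bert::BertConfig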
data/lib/transformers/models/auto/feature_extraction_auto.rb
@@ -0,0 +1,20 @@
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   FEATURE_EXTRACTOR_MAPPING_NAMES = {
+   }
+
+   FEATURE_EXTRACTOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
+ end
data/lib/transformers/models/auto/image_processing_auto.rb
@@ -0,0 +1,104 @@
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   IMAGE_PROCESSOR_MAPPING_NAMES = {
+     "vit" => ["ViTImageProcessor"]
+   }
+
+   IMAGE_PROCESSOR_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
+
+   class AutoImageProcessor
+     def self.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+       config = kwargs.delete(:config)
+       use_fast = kwargs.delete(:use_fast)
+       trust_remote_code = kwargs.delete(:trust_remote_code)
+       kwargs[:_from_auto] = true
+
+       config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+       image_processor_class = config_dict[:image_processor_type]
+       image_processor_auto_map = nil
+       if (config_dict[:auto_map] || {}).include?("AutoImageProcessor")
+         image_processor_auto_map = config_dict[:auto_map]["AutoImageProcessor"]
+       end
+
+       # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
+       # and if so, infer the image processor class from there.
+       if image_processor_class.nil? && image_processor_auto_map.nil?
+         feature_extractor_class = config_dict.delete(:feature_extractor_type)
+         if !feature_extractor_class.nil?
+           image_processor_class = feature_extractor_class.sub("FeatureExtractor", "ImageProcessor")
+         end
+         if (config_dict[:auto_map] || {}).include?("AutoFeatureExtractor")
+           feature_extractor_auto_map = config_dict[:auto_map]["AutoFeatureExtractor"]
+           image_processor_auto_map = feature_extractor_auto_map.sub("FeatureExtractor", "ImageProcessor")
+         end
+       end
+
+       # If we don't find the image processor class in the image processor config, let's try the model config.
+       if image_processor_class.nil? && image_processor_auto_map.nil?
+         if !config.is_a?(PretrainedConfig)
+           config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         end
+         # It could be in `config.image_processor_type`
+         image_processor_class = config.instance_variable_get(:@image_processor_type)
+       end
+
+       if !image_processor_class.nil?
+         raise Todo
+       end
+
+       has_remote_code = !image_processor_auto_map.nil?
+       has_local_code = !image_processor_class.nil? || IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
+       trust_remote_code = DynamicModuleUtils.resolve_trust_remote_code(
+         trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
+       )
+
+       if !image_processor_auto_map.nil? && !image_processor_auto_map.is_a?(Array)
+         raise Todo
+       end
+
+       if has_remote_code && trust_remote_code
+         raise Todo
+       elsif !image_processor_class.nil?
+         return image_processor_class.from_dict(config_dict, **kwargs)
+       # Last try: we use the IMAGE_PROCESSOR_MAPPING.
+       elsif IMAGE_PROCESSOR_MAPPING.include?(config.class.name.split("::").last)
+         image_processor_tuple = IMAGE_PROCESSOR_MAPPING[config.class.name.split("::").last]
+
+         image_processor_class_py, image_processor_class_fast = image_processor_tuple
+
+         if !use_fast && !image_processor_class_fast.nil?
+           _warning_fast_image_processor_available(image_processor_class_fast)
+         end
+
+         if image_processor_class_fast && (use_fast || image_processor_class_py.nil?)
+           return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+         else
+           if !image_processor_class_py.nil?
+             return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+           else
+             raise ArgumentError,
+               "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
+           end
+         end
+       end
+
+       raise ArgumentError,
+         "Unrecognized image processor in #{pretrained_model_name_or_path}. Should have a " +
+         "`image_processor_type` key in its #{IMAGE_PROCESSOR_NAME} of #{CONFIG_NAME}, or one of the following " +
+         "`model_type` keys in its #{CONFIG_NAME}: #{IMAGE_PROCESSOR_MAPPING_NAMES.keys.join(", ")}"
+     end
+   end
+ end
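A note on the last-try branch above: because IMAGE_PROCESSOR_MAPPING_NAMES stores an Array per model type, the mapping lookup returns an Array, which the code destructures into a slow and a fast class. A minimal sketch (not part of the diff, assuming the Vit module is loaded):

    Transformers::IMAGE_PROCESSOR_MAPPING.include?("ViTConfig") # => true
    Transformers::IMAGE_PROCESSOR_MAPPING["ViTConfig"]
    # => [Transformers::Vit::ViTImageProcessor]
    # destructured above into image_processor_class_py (ViTImageProcessor)
    # and image_processor_class_fast (nil, as no fast variant is registered)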
data/lib/transformers/models/auto/modeling_auto.rb
@@ -0,0 +1,80 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   MODEL_MAPPING_NAMES = {
+     "bert" => "BertModel",
+     "distilbert" => "DistilBertModel",
+     "vit" => "ViTModel"
+   }
+
+   MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+     "bert" => "BertForMaskedLM"
+   }
+
+   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+     "distilbert" => "DistilBertForSequenceClassification"
+   }
+
+   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+     "distilbert" => "DistilBertForQuestionAnswering"
+   }
+
+   MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+     "vit" => "ViTForImageClassification"
+   }
+
+   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => "BertForTokenClassification"
+   }
+
+   MODEL_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
+   MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
+     CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+   )
+   MODEL_FOR_MASKED_LM_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
+     CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+   )
+   MODEL_FOR_QUESTION_ANSWERING_MAPPING = LazyAutoMapping.new(
+     CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+   )
+   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = LazyAutoMapping.new(
+     CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+   )
+
+   class AutoModel < BaseAutoModelClass
+     self._model_mapping = MODEL_MAPPING
+   end
+
+   class AutoModelForMaskedLM < BaseAutoModelClass
+     self._model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+   end
+
+   class AutoModelForSequenceClassification < BaseAutoModelClass
+     self._model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+   end
+
+   class AutoModelForQuestionAnswering < BaseAutoModelClass
+     self._model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
+   end
+
+   class AutoModelForTokenClassification < BaseAutoModelClass
+     self._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+   end
+
+   class AutoModelForImageClassification < BaseAutoModelClass
+     self._model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+   end
+ end
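Illustrative usage of the concrete Auto classes (not part of the diff; the checkpoint id is an example and weights are fetched from the Hub on first use):

    model = Transformers::AutoModelForSequenceClassification.from_pretrained(
      "distilbert-base-uncased-finetuned-sst-2-english" # example checkpoint
    )
    # config.json gives model_type "distilbert", so _get_model_class resolves
    # MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING["DistilBertConfig"] to
    # Transformers::Distilbert::DistilBertForSequenceClassification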
data/lib/transformers/models/auto/tokenization_auto.rb
@@ -0,0 +1,160 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   TOKENIZER_MAPPING_NAMES = {
+     "bert" => ["BertTokenizer", "BertTokenizerFast"],
+     "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
+   }
+
+   TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
+
+   class AutoTokenizer
+     class << self
+       def from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+         config = kwargs.delete(:config)
+         kwargs[:_from_auto] = true
+
+         use_fast = kwargs.delete(:use_fast) { true }
+         tokenizer_type = kwargs.delete(:tokenizer_type) { nil }
+         trust_remote_code = kwargs.delete(:trust_remote_code)
+
+         if !tokenizer_type.nil?
+           raise Todo
+         end
+
+         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+         if tokenizer_config.include?("_commit_hash")
+           kwargs[:_commit_hash] = tokenizer_config["_commit_hash"]
+         end
+         config_tokenizer_class = tokenizer_config["tokenizer_class"]
+         _tokenizer_auto_map = nil
+         if tokenizer_config["auto_map"]
+           raise Todo
+         end
+
+         # If that did not work, let's try to use the config.
+         if config_tokenizer_class.nil?
+           if !config.is_a?(PretrainedConfig)
+             config = AutoConfig.from_pretrained(
+               pretrained_model_name_or_path, trust_remote_code: trust_remote_code, **kwargs
+             )
+             config_tokenizer_class = config.tokenizer_class
+             # if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+             #   tokenizer_auto_map = config.auto_map["AutoTokenizer"]
+           end
+         end
+
+         if !config_tokenizer_class.nil?
+           tokenizer_class = nil
+           if use_fast && !config_tokenizer_class.end_with?("Fast")
+             tokenizer_class_candidate = "#{config_tokenizer_class}Fast"
+             tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+           end
+           if tokenizer_class.nil?
+             tokenizer_class_candidate = config_tokenizer_class
+             tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+           end
+           if tokenizer_class.nil?
+             raise ArgumentError, "Tokenizer class #{tokenizer_class_candidate} does not exist or is not currently imported."
+           end
+           return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+         end
+
+         model_type = config_class_to_model_type(config.class.name.split("::").last)
+         if !model_type.nil?
+           tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config.class.name.split("::").last]
+           if tokenizer_class_fast && (use_fast || tokenizer_class_py.nil?)
+             return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+           else
+             if !tokenizer_class_py.nil?
+               return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+             else
+               raise ArgumentError, "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed in order to use this tokenizer."
+             end
+           end
+         end
+
+         raise ArgumentError, "Unrecognized configuration class #{config.class.name} to build an AutoTokenizer."
+       end
+
+       private
+
+       def tokenizer_class_from_name(class_name)
+         if class_name == "PreTrainedTokenizerFast"
+           return PreTrainedTokenizerFast
+         end
+
+         TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
+           if tokenizers.include?(class_name)
+             cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+             raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
+             return cls
+           end
+         end
+
+         raise Todo
+       end
+
+       def get_tokenizer_config(
+         pretrained_model_name_or_path,
+         cache_dir: nil,
+         force_download: false,
+         resume_download: false,
+         proxies: nil,
+         token: nil,
+         revision: nil,
+         local_files_only: false,
+         subfolder: "",
+         **kwargs
+       )
+         commit_hash = kwargs[:_commit_hash]
+         resolved_config_file = Utils::Hub.cached_file(
+           pretrained_model_name_or_path,
+           TOKENIZER_CONFIG_FILE,
+           cache_dir: cache_dir,
+           force_download: force_download,
+           resume_download: resume_download,
+           proxies: proxies,
+           token: token,
+           revision: revision,
+           local_files_only: local_files_only,
+           subfolder: subfolder,
+           _raise_exceptions_for_gated_repo: false,
+           _raise_exceptions_for_missing_entries: false,
+           _raise_exceptions_for_connection_errors: false,
+           _commit_hash: commit_hash
+         )
+         if resolved_config_file.nil?
+           Transformers.logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
+           return {}
+         end
+         commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+
+         result = JSON.load_file(resolved_config_file)
+         result["_commit_hash"] = commit_hash
+         result
+       end
+
+       def config_class_to_model_type(config)
+         CONFIG_MAPPING_NAMES.each do |key, cls|
+           if cls == config
+             return key
+           end
+         end
+         nil
+       end
+     end
+   end
+ end
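Illustrative usage of AutoTokenizer (not part of the diff; the model id is an example, and the resolved class depends on what the checkpoint's tokenizer_config.json declares):

    tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
    # with use_fast (the default), tokenizer_class_from_name first tries the
    # "...Fast" candidate, so a BERT checkpoint is expected to yield
    tokenizer.class # => Transformers::Bert::BertTokenizerFast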
data/lib/transformers/models/bert/configuration_bert.rb
@@ -0,0 +1,65 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Bert
+     class BertConfig < PretrainedConfig
+       self.model_type = "bert"
+
+       attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+         :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+         :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+         :position_embedding_type, :use_cache, :classifier_dropout
+
+       def initialize(
+         vocab_size: 30522,
+         hidden_size: 768,
+         num_hidden_layers: 12,
+         num_attention_heads: 12,
+         intermediate_size: 3072,
+         hidden_act: "gelu",
+         hidden_dropout_prob: 0.1,
+         attention_probs_dropout_prob: 0.1,
+         max_position_embeddings: 512,
+         type_vocab_size: 2,
+         initializer_range: 0.02,
+         layer_norm_eps: 1e-12,
+         pad_token_id: 0,
+         position_embedding_type: "absolute",
+         use_cache: true,
+         classifier_dropout: nil,
+         **kwargs
+       )
+         super(pad_token_id: pad_token_id, **kwargs)
+
+         @vocab_size = vocab_size
+         @hidden_size = hidden_size
+         @num_hidden_layers = num_hidden_layers
+         @num_attention_heads = num_attention_heads
+         @hidden_act = hidden_act
+         @intermediate_size = intermediate_size
+         @hidden_dropout_prob = hidden_dropout_prob
+         @attention_probs_dropout_prob = attention_probs_dropout_prob
+         @max_position_embeddings = max_position_embeddings
+         @type_vocab_size = type_vocab_size
+         @initializer_range = initializer_range
+         @layer_norm_eps = layer_norm_eps
+         @position_embedding_type = position_embedding_type
+         @use_cache = use_cache
+         @classifier_dropout = classifier_dropout
+       end
+     end
+   end
+ end
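A minimal sketch of constructing the config directly (not part of the diff); keyword arguments left unspecified keep the BERT-base defaults shown above:

    config = Transformers::Bert::BertConfig.new(hidden_size: 512, num_attention_heads: 8)
    config.hidden_size       # => 512
    config.num_hidden_layers # => 12 (default)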