transformers-rb 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +61 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +10 -0
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +14 -4
@@ -0,0 +1,106 @@
|
|
1
|
+
# Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
|
2
|
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
module Transformers
|
17
|
+
module Mpnet
|
18
|
+
class MPNetTokenizerFast < PreTrainedTokenizerFast
|
19
|
+
VOCAB_FILES_NAMES = {vocab_file: "vocab.txt", tokenizer_file: "tokenizer.json"}
|
20
|
+
|
21
|
+
self.vocab_files_names = VOCAB_FILES_NAMES
|
22
|
+
# self.slow_tokenizer_class = MPNetTokenizer
|
23
|
+
self.model_input_names = ["input_ids", "attention_mask"]
|
24
|
+
|
25
|
+
# Builds the fast MPNet tokenizer.
#
# Special tokens given as plain Strings are wrapped in Tokenizers::AddedToken
# (with rstrip disabled) before being forwarded to the parent initializer; any
# non-String value is forwarded untouched.
#
# @param vocab_file [Object, nil] passed positionally to the parent initializer
# @param tokenizer_file [Object, nil] forwarded as +tokenizer_file:+
# @param do_lower_case [Boolean] forwarded to the parent and stored in @do_lower_case
# @param bos_token [String, Object] wrapped with lstrip: false when a String
# @param eos_token [String, Object] wrapped with lstrip: false when a String
# @param sep_token [String, Object] wrapped with lstrip: false when a String
# @param cls_token [String, Object] wrapped with lstrip: false when a String
# @param unk_token [String, Object] wrapped with lstrip: false when a String
# @param pad_token [String, Object] wrapped with lstrip: false when a String
# @param mask_token [String, Object] wrapped with lstrip: true when a String
# @param tokenize_chinese_chars [Boolean] forwarded to the parent initializer
# @param strip_accents [Object, nil] forwarded to the parent initializer
# @param kwargs [Hash] any additional options, forwarded as-is
def initialize(
  vocab_file: nil,
  tokenizer_file: nil,
  do_lower_case: true,
  bos_token: "<s>",
  eos_token: "</s>",
  sep_token: "</s>",
  cls_token: "<s>",
  unk_token: "[UNK]",
  pad_token: "<pad>",
  mask_token: "<mask>",
  tokenize_chinese_chars: true,
  strip_accents: nil,
  **kwargs
)
  # Wraps a plain String as an AddedToken; anything else is returned unchanged.
  as_added_token = lambda do |token, lstrip|
    next token unless token.is_a?(String)

    Tokenizers::AddedToken.new(token, lstrip: lstrip, rstrip: false)
  end

  bos_token = as_added_token.call(bos_token, false)
  eos_token = as_added_token.call(eos_token, false)
  sep_token = as_added_token.call(sep_token, false)
  cls_token = as_added_token.call(cls_token, false)
  unk_token = as_added_token.call(unk_token, false)
  pad_token = as_added_token.call(pad_token, false)

  # The mask token behaves like a normal word, i.e. it includes the space before it,
  # hence lstrip is enabled for it only.
  mask_token = as_added_token.call(mask_token, true)

  super(
    vocab_file,
    tokenizer_file: tokenizer_file,
    do_lower_case: do_lower_case,
    bos_token: bos_token,
    eos_token: eos_token,
    sep_token: sep_token,
    cls_token: cls_token,
    unk_token: unk_token,
    pad_token: pad_token,
    mask_token: mask_token,
    tokenize_chinese_chars: tokenize_chinese_chars,
    strip_accents: strip_accents,
    **kwargs
  )

  # TODO support
  # pre_tok_state = JSON.parse(backend_tokenizer.normalizer.__getstate__)
  # if (pre_tok_state["lowercase"] || do_lower_case) != do_lower_case || (pre_tok_state["strip_accents"] || strip_accents) != strip_accents
  #   pre_tok_class = getattr(normalizers, pre_tok_state.delete("type"))
  #   pre_tok_state["lowercase"] = do_lower_case
  #   pre_tok_state["strip_accents"] = strip_accents
  #   @normalizer = pre_tok_class(**pre_tok_state)
  # end

  @do_lower_case = do_lower_case
end
|
63
|
+
|
64
|
+
# Returns the mask token as a String.
#
# When no mask token has been set, returns nil; in that case an error is logged
# through Transformers.logger, but only if @verbose is truthy.
#
# @return [String, nil] the string form of @mask_token, or nil when it is unset
def mask_token
  return @mask_token.to_s unless @mask_token.nil?

  Transformers.logger.error("Using mask_token, but it is not set yet.") if @verbose
  nil
end
|
73
|
+
|
74
|
+
# Sets the mask token.
#
# The mask token behaves like a normal word, i.e. it includes the space before it,
# so a plain String is wrapped in an AddedToken with lstrip enabled (and rstrip
# disabled). Any non-String value is stored unchanged.
#
# @param value [String, Object] the new mask token
def mask_token=(value)
  @mask_token =
    if value.is_a?(String)
      Tokenizers::AddedToken.new(value, lstrip: true, rstrip: false)
    else
      value
    end
end
|
80
|
+
|
81
|
+
# Surrounds one sequence, or a pair of sequences, with the special token ids.
#
# Layouts produced (bos/eos being @bos_token_id / @eos_token_id):
#   single: bos + token_ids_0 + eos
#   pair:   bos + token_ids_0 + eos + eos + token_ids_1 + eos
#
# NOTE(review): the ids are read straight from the @bos_token_id / @eos_token_id
# instance variables; confirm the parent class actually populates them (rather
# than exposing the ids only through reader methods), otherwise they are nil here.
#
# @param token_ids_0 [Array] ids of the first sequence
# @param token_ids_1 [Array, nil] ids of the optional second sequence
# @return [Array] the ids with special tokens added
def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
  bos = [@bos_token_id]
  eos = [@eos_token_id]

  single = bos + token_ids_0 + eos
  token_ids_1.nil? ? single : single + eos + token_ids_1 + eos
end
|
89
|
+
|
90
|
+
# Builds the token type ids for one sequence or a pair of sequences.
#
# Every position gets type id 0; only the length of the result varies. The length
# matches the special-token layouts cls + A + sep (single) and
# cls + A + sep + sep + B + sep (pair).
#
# The previous implementation computed `length * [0]`, which raises TypeError in
# Ruby because Integer#* cannot take an Array operand; the array must be the
# receiver of the repetition.
#
# @param token_ids_0 [Array] ids of the first sequence
# @param token_ids_1 [Array, nil] ids of the optional second sequence
# @return [Array<Integer>] an array of zeros, one per token including special tokens
def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
  # cls + token_ids_0 + sep
  total = token_ids_0.length + 2

  # + sep + token_ids_1 + sep
  total += token_ids_1.length + 2 unless token_ids_1.nil?

  Array.new(total, 0)
end
|
99
|
+
|
100
|
+
# Saves the vocabulary by delegating to the backend tokenizer's model.
#
# Calls @tokenizer.model.save with the directory and passes filename_prefix as the
# +name:+ option, then coerces whatever that call returns into an Array via
# Kernel#Array (nil becomes [], a single value becomes a one-element array).
#
# @param save_directory [Object] destination handed to the backend model's save
# @param filename_prefix [Object, nil] forwarded as +name:+
# @return [Array] the coerced result of the backend save call
def save_vocabulary(save_directory, filename_prefix: nil)
  Array(@tokenizer.model.save(save_directory, name: filename_prefix))
end
|
104
|
+
end
|
105
|
+
end
|
106
|
+
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
2
|
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
module Transformers
|
17
|
+
module XlmRoberta
|
18
|
+
# Configuration for XLM-RoBERTa models.
#
# Stores the architecture hyper-parameters as read-only attributes. The three
# special token ids are handed to the parent initializer together with any
# extra keyword arguments; everything else is kept on this instance.
class XLMRobertaConfig < PretrainedConfig
  self.model_type = "xlm-roberta"

  # Model dimensions
  attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
    :intermediate_size, :max_position_embeddings, :type_vocab_size

  # Activation, regularisation and initialisation
  attr_reader :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
    :initializer_range, :layer_norm_eps, :classifier_dropout

  # Special token ids (passed to the parent initializer)
  attr_reader :pad_token_id, :bos_token_id, :eos_token_id

  # Miscellaneous behaviour
  attr_reader :position_embedding_type, :use_cache

  # @param vocab_size [Integer] stored as-is (default 30522)
  # @param hidden_size [Integer] stored as-is (default 768)
  # @param num_hidden_layers [Integer] stored as-is (default 12)
  # @param num_attention_heads [Integer] stored as-is (default 12)
  # @param intermediate_size [Integer] stored as-is (default 3072)
  # @param hidden_act [String] stored as-is (default "gelu")
  # @param hidden_dropout_prob [Float] stored as-is (default 0.1)
  # @param attention_probs_dropout_prob [Float] stored as-is (default 0.1)
  # @param max_position_embeddings [Integer] stored as-is (default 512)
  # @param type_vocab_size [Integer] stored as-is (default 2)
  # @param initializer_range [Float] stored as-is (default 0.02)
  # @param layer_norm_eps [Float] stored as-is (default 1e-12)
  # @param pad_token_id [Integer] forwarded to the parent initializer (default 1)
  # @param bos_token_id [Integer] forwarded to the parent initializer (default 0)
  # @param eos_token_id [Integer] forwarded to the parent initializer (default 2)
  # @param position_embedding_type [String] stored as-is (default "absolute")
  # @param use_cache [Boolean] stored as-is (default true)
  # @param classifier_dropout [Object, nil] stored as-is (default nil)
  # @param kwargs [Hash] extra options forwarded to the parent initializer
  def initialize(
    vocab_size: 30522,
    hidden_size: 768,
    num_hidden_layers: 12,
    num_attention_heads: 12,
    intermediate_size: 3072,
    hidden_act: "gelu",
    hidden_dropout_prob: 0.1,
    attention_probs_dropout_prob: 0.1,
    max_position_embeddings: 512,
    type_vocab_size: 2,
    initializer_range: 0.02,
    layer_norm_eps: 1e-12,
    pad_token_id: 1,
    bos_token_id: 0,
    eos_token_id: 2,
    position_embedding_type: "absolute",
    use_cache: true,
    classifier_dropout: nil,
    **kwargs
  )
    super(
      pad_token_id: pad_token_id,
      bos_token_id: bos_token_id,
      eos_token_id: eos_token_id,
      **kwargs
    )

    # Model dimensions
    @vocab_size = vocab_size
    @hidden_size = hidden_size
    @num_hidden_layers = num_hidden_layers
    @num_attention_heads = num_attention_heads
    @intermediate_size = intermediate_size
    @max_position_embeddings = max_position_embeddings
    @type_vocab_size = type_vocab_size

    # Activation, regularisation and initialisation
    @hidden_act = hidden_act
    @hidden_dropout_prob = hidden_dropout_prob
    @attention_probs_dropout_prob = attention_probs_dropout_prob
    @initializer_range = initializer_range
    @layer_norm_eps = layer_norm_eps
    @classifier_dropout = classifier_dropout

    # Miscellaneous behaviour
    @position_embedding_type = position_embedding_type
    @use_cache = use_cache
  end
end
|
67
|
+
end
|
68
|
+
end
|