transformers-rb 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +61 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +10 -0
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +14 -4
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License
|
14
|
+
|
15
|
+
module Transformers
  module XlmRoberta
    # Fast (tokenizers-backed) tokenizer for XLM-RoBERTa checkpoints.
    #
    # Special tokens are laid out as:
    #   single sequence:   <s> A </s>
    #   pair of sequences: <s> A </s></s> B </s>
    class XLMRobertaTokenizerFast < PreTrainedTokenizerFast
      VOCAB_FILES_NAMES = {vocab_file: "sentencepiece.bpe.model", tokenizer_file: "tokenizer.json"}

      self.vocab_files_names = VOCAB_FILES_NAMES
      self.model_input_names = ["input_ids", "attention_mask"]
      # self.slow_tokenizer_class = XLMRobertaTokenizer

      # @param vocab_file [String, nil] path to the SentencePiece model; kept so
      #   #can_save_slow_tokenizer can check whether it exists on disk
      # @param tokenizer_file [String, nil] path to a serialized fast tokenizer (tokenizer.json)
      # @param bos_token, eos_token, sep_token, cls_token, unk_token, pad_token [String]
      #   special-token strings forwarded to the superclass
      # @param mask_token [String, Object] a String is wrapped in a Tokenizers::AddedToken
      #   with lstrip: true; any other value is forwarded unchanged
      # @param kwargs [Hash] extra options forwarded to PreTrainedTokenizerFast
      def initialize(
        vocab_file: nil,
        tokenizer_file: nil,
        bos_token: "<s>",
        eos_token: "</s>",
        sep_token: "</s>",
        cls_token: "<s>",
        unk_token: "<unk>",
        pad_token: "<pad>",
        mask_token: "<mask>",
        **kwargs
      )
        # The mask token behaves like a normal word, i.e. it absorbs the space before it
        mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token

        super(vocab_file, tokenizer_file: tokenizer_file, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, **kwargs)

        @vocab_file = vocab_file
      end

      # @return [Boolean] true only when a vocab_file was given and it exists on disk
      def can_save_slow_tokenizer
        @vocab_file ? File.exist?(@vocab_file) : false
      end

      # Wraps one or two id sequences with the CLS/SEP special-token ids.
      # NOTE(review): @cls_token_id / @sep_token_id are assumed to be set by the
      # superclass — confirm.
      #
      # @param token_ids_0 [Array<Integer>] first sequence
      # @param token_ids_1 [Array<Integer>, nil] optional second sequence
      # @return [Array<Integer>]
      def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
        cls = [@cls_token_id]
        sep = [@sep_token_id]
        return cls + token_ids_0 + sep if token_ids_1.nil?

        cls + token_ids_0 + sep + sep + token_ids_1 + sep
      end

      # XLM-RoBERTa does not use token type ids, so this returns a list of zeros
      # whose length matches #build_inputs_with_special_tokens for the same input.
      #
      # @param token_ids_0 [Array<Integer>] first sequence
      # @param token_ids_1 [Array<Integer>, nil] optional second sequence
      # @return [Array<Integer>] all zeros
      def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
        sep = [@sep_token_id]
        cls = [@cls_token_id]

        # Array#* repeats the array. `Integer * Array` (the Python `len(x) * [0]`
        # idiom) raises TypeError in Ruby, so the array must be the receiver.
        return [0] * (cls + token_ids_0 + sep).length if token_ids_1.nil?

        [0] * (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length
      end
    end
  end
end
|
@@ -89,6 +89,16 @@ module Transformers
|
|
89
89
|
},
|
90
90
|
"type" => "text"
|
91
91
|
},
|
92
|
+
"reranking" => {
|
93
|
+
"impl" => RerankingPipeline,
|
94
|
+
"pt" => [AutoModelForSequenceClassification],
|
95
|
+
"default" => {
|
96
|
+
"model" => {
|
97
|
+
"pt" => ["mixedbread-ai/mxbai-rerank-base-v1", "03241da"]
|
98
|
+
}
|
99
|
+
},
|
100
|
+
"type" => "text"
|
101
|
+
}
|
92
102
|
}
|
93
103
|
|
94
104
|
PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Transformers
  # Scores every document against a single query with a sequence-classification
  # (cross-encoder) model and returns the documents ranked by score.
  class RerankingPipeline < Pipeline
    # Accepts no preprocess/forward options; every keyword is placed in the
    # third slot of the returned triple.
    # NOTE(review): the triple is presumably [preprocess, forward, postprocess]
    # params per the Pipeline base class — confirm.
    def _sanitize_parameters(**kwargs)
      [{}, {}, kwargs]
    end

    # Tokenizes the query paired with each document, repeating the query once
    # per document so the whole list is encoded as a single batch.
    #
    # @param inputs [Hash] {query: String, documents: Array<String>}
    def preprocess(inputs)
      @tokenizer.(
        [inputs[:query]] * inputs[:documents].length,
        text_pair: inputs[:documents],
        return_tensors: @framework
      )
    end

    # Runs the model on the tokenized batch and returns its raw outputs.
    def _forward(model_inputs)
      @model.(**model_inputs)
    end

    # @param query [String] the search query
    # @param documents [Array<String>] candidate documents to score
    # @return [Array<Hash>] one {index:, score:} hash per document, where index
    #   is the document's position in +documents+, sorted by descending score
    def call(query, documents)
      # An empty candidate list has an empty ranking; don't send an empty batch
      # through the tokenizer and model.
      return [] if documents.empty?

      super({query: query, documents: documents})
    end

    # Converts the first model output to sigmoid scores and ranks documents.
    # NOTE(review): model_outputs[0] is assumed to be logits of shape
    # (num_documents, 1) — confirm.
    def postprocess(model_outputs)
      model_outputs[0]
        .sigmoid
        # Squeeze only the trailing label dimension. A bare squeeze would also
        # drop the batch dimension when there is a single document, producing
        # a 0-d tensor.
        .squeeze(-1)
        .to_a
        .map.with_index { |s, i| {index: i, score: s} }
        .sort_by { |v| -v[:score] }
    end
  end
end
|
data/lib/transformers/version.rb
CHANGED
data/lib/transformers.rb
CHANGED
@@ -61,17 +61,32 @@ require_relative "transformers/models/bert/modeling_bert"
|
|
61
61
|
require_relative "transformers/models/bert/tokenization_bert"
|
62
62
|
require_relative "transformers/models/bert/tokenization_bert_fast"
|
63
63
|
|
64
|
+
# models deberta-v2
|
65
|
+
require_relative "transformers/models/deberta_v2/configuration_deberta_v2"
|
66
|
+
require_relative "transformers/models/deberta_v2/modeling_deberta_v2"
|
67
|
+
require_relative "transformers/models/deberta_v2/tokenization_deberta_v2_fast"
|
68
|
+
|
64
69
|
# models distilbert
|
65
70
|
require_relative "transformers/models/distilbert/configuration_distilbert"
|
66
71
|
require_relative "transformers/models/distilbert/modeling_distilbert"
|
67
72
|
require_relative "transformers/models/distilbert/tokenization_distilbert"
|
68
73
|
require_relative "transformers/models/distilbert/tokenization_distilbert_fast"
|
69
74
|
|
75
|
+
# models mpnet
|
76
|
+
require_relative "transformers/models/mpnet/configuration_mpnet"
|
77
|
+
require_relative "transformers/models/mpnet/modeling_mpnet"
|
78
|
+
require_relative "transformers/models/mpnet/tokenization_mpnet_fast"
|
79
|
+
|
70
80
|
# models vit
|
71
81
|
require_relative "transformers/models/vit/configuration_vit"
|
72
82
|
require_relative "transformers/models/vit/image_processing_vit"
|
73
83
|
require_relative "transformers/models/vit/modeling_vit"
|
74
84
|
|
85
|
+
# models xlm-roberta
|
86
|
+
require_relative "transformers/models/xlm_roberta/configuration_xlm_roberta"
|
87
|
+
require_relative "transformers/models/xlm_roberta/modeling_xlm_roberta"
|
88
|
+
require_relative "transformers/models/xlm_roberta/tokenization_xlm_roberta_fast"
|
89
|
+
|
75
90
|
# pipelines
|
76
91
|
require_relative "transformers/pipelines/base"
|
77
92
|
require_relative "transformers/pipelines/feature_extraction"
|
@@ -80,6 +95,7 @@ require_relative "transformers/pipelines/image_classification"
|
|
80
95
|
require_relative "transformers/pipelines/image_feature_extraction"
|
81
96
|
require_relative "transformers/pipelines/pt_utils"
|
82
97
|
require_relative "transformers/pipelines/question_answering"
|
98
|
+
require_relative "transformers/pipelines/reranking"
|
83
99
|
require_relative "transformers/pipelines/text_classification"
|
84
100
|
require_relative "transformers/pipelines/token_classification"
|
85
101
|
require_relative "transformers/pipelines/_init"
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: transformers-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.2
|
4
|
+
version: 0.1.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-09-
|
11
|
+
date: 2024-09-17 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -44,14 +44,14 @@ dependencies:
|
|
44
44
|
requirements:
|
45
45
|
- - ">="
|
46
46
|
- !ruby/object:Gem::Version
|
47
|
-
version: 0.5.
|
47
|
+
version: 0.5.3
|
48
48
|
type: :runtime
|
49
49
|
prerelease: false
|
50
50
|
version_requirements: !ruby/object:Gem::Requirement
|
51
51
|
requirements:
|
52
52
|
- - ">="
|
53
53
|
- !ruby/object:Gem::Version
|
54
|
-
version: 0.5.
|
54
|
+
version: 0.5.3
|
55
55
|
- !ruby/object:Gem::Dependency
|
56
56
|
name: torch-rb
|
57
57
|
requirement: !ruby/object:Gem::Requirement
|
@@ -104,13 +104,22 @@ files:
|
|
104
104
|
- lib/transformers/models/bert/modeling_bert.rb
|
105
105
|
- lib/transformers/models/bert/tokenization_bert.rb
|
106
106
|
- lib/transformers/models/bert/tokenization_bert_fast.rb
|
107
|
+
- lib/transformers/models/deberta_v2/configuration_deberta_v2.rb
|
108
|
+
- lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
|
109
|
+
- lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb
|
107
110
|
- lib/transformers/models/distilbert/configuration_distilbert.rb
|
108
111
|
- lib/transformers/models/distilbert/modeling_distilbert.rb
|
109
112
|
- lib/transformers/models/distilbert/tokenization_distilbert.rb
|
110
113
|
- lib/transformers/models/distilbert/tokenization_distilbert_fast.rb
|
114
|
+
- lib/transformers/models/mpnet/configuration_mpnet.rb
|
115
|
+
- lib/transformers/models/mpnet/modeling_mpnet.rb
|
116
|
+
- lib/transformers/models/mpnet/tokenization_mpnet_fast.rb
|
111
117
|
- lib/transformers/models/vit/configuration_vit.rb
|
112
118
|
- lib/transformers/models/vit/image_processing_vit.rb
|
113
119
|
- lib/transformers/models/vit/modeling_vit.rb
|
120
|
+
- lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb
|
121
|
+
- lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
|
122
|
+
- lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb
|
114
123
|
- lib/transformers/pipelines/_init.rb
|
115
124
|
- lib/transformers/pipelines/base.rb
|
116
125
|
- lib/transformers/pipelines/embedding.rb
|
@@ -119,6 +128,7 @@ files:
|
|
119
128
|
- lib/transformers/pipelines/image_feature_extraction.rb
|
120
129
|
- lib/transformers/pipelines/pt_utils.rb
|
121
130
|
- lib/transformers/pipelines/question_answering.rb
|
131
|
+
- lib/transformers/pipelines/reranking.rb
|
122
132
|
- lib/transformers/pipelines/text_classification.rb
|
123
133
|
- lib/transformers/pipelines/token_classification.rb
|
124
134
|
- lib/transformers/ruby_utils.rb
|