transformers-rb 0.1.2 → 0.1.3

data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb ADDED
@@ -0,0 +1,68 @@
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License
+
+ module Transformers
+   module XlmRoberta
+     class XLMRobertaTokenizerFast < PreTrainedTokenizerFast
+       VOCAB_FILES_NAMES = {vocab_file: "sentencepiece.bpe.model", tokenizer_file: "tokenizer.json"}
+
+       self.vocab_files_names = VOCAB_FILES_NAMES
+       self.model_input_names = ["input_ids", "attention_mask"]
+       # self.slow_tokenizer_class = XLMRobertaTokenizer
+
+       def initialize(
+         vocab_file: nil,
+         tokenizer_file: nil,
+         bos_token: "<s>",
+         eos_token: "</s>",
+         sep_token: "</s>",
+         cls_token: "<s>",
+         unk_token: "<unk>",
+         pad_token: "<pad>",
+         mask_token: "<mask>",
+         **kwargs
+       )
+         # The mask token behaves like a normal word, i.e. it includes the space before it
+         mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token
+
+         super(vocab_file, tokenizer_file: tokenizer_file, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, **kwargs)
+
+         @vocab_file = vocab_file
+       end
+
+       def can_save_slow_tokenizer
+         @vocab_file ? File.exist?(@vocab_file) : false
+       end
+
+       def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
+         if token_ids_1.nil?
+           return [@cls_token_id] + token_ids_0 + [@sep_token_id]
+         end
+         cls = [@cls_token_id]
+         sep = [@sep_token_id]
+         cls + token_ids_0 + sep + sep + token_ids_1 + sep
+       end
+
+       def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
+         sep = [@sep_token_id]
+         cls = [@cls_token_id]
+
+         if token_ids_1.nil?
+           return [0] * (cls + token_ids_0 + sep).length
+         end
+         [0] * (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length
+       end
+     end
+   end
+ end
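
The two helpers above mirror the RoBERTa-style special-token layout: a single sequence becomes `<s> X </s>`, a pair becomes `<s> A </s></s> B </s>`, and token type ids are always zero. A minimal, self-contained sketch of what that concatenation produces, using hypothetical token ids (0 for `<s>` and 2 for `</s>` are the usual XLM-RoBERTa values, but the real ids come from the loaded tokenizer):

```ruby
# Illustrative stand-ins for @cls_token_id / @sep_token_id (hypothetical values).
cls_id = 0
sep_id = 2

token_ids_a = [581, 1176, 98]   # hypothetical ids for sequence A
token_ids_b = [2367, 44]        # hypothetical ids for sequence B

# Single sequence: <s> A </s>
single = [cls_id] + token_ids_a + [sep_id]
# => [0, 581, 1176, 98, 2]

# Pair of sequences: <s> A </s></s> B </s>
pair = [cls_id] + token_ids_a + [sep_id, sep_id] + token_ids_b + [sep_id]
# => [0, 581, 1176, 98, 2, 2, 2367, 44, 2]

# XLM-RoBERTa does not use token type ids, so the mask is all zeros
token_type_ids = [0] * pair.length
```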
data/lib/transformers/pipelines/_init.rb CHANGED
@@ -89,6 +89,16 @@ module Transformers
        },
        "type" => "text"
      },
+     "reranking" => {
+       "impl" => RerankingPipeline,
+       "pt" => [AutoModelForSequenceClassification],
+       "default" => {
+         "model" => {
+           "pt" => ["mixedbread-ai/mxbai-rerank-base-v1", "03241da"]
+         }
+       },
+       "type" => "text"
+     }
    }
 
    PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
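
With this registry entry, the new task should be reachable through the gem's `Transformers.pipeline` helper like the existing tasks, defaulting to the pinned mixedbread-ai/mxbai-rerank-base-v1 checkpoint (revision 03241da). A hedged usage sketch; the exact entry point and defaults should be checked against the released gem:

```ruby
require "transformers"

# Builds a RerankingPipeline using the default model registered in SUPPORTED_TASKS.
rerank = Transformers.pipeline("reranking")

# RerankingPipeline#call takes a query plus an array of candidate documents.
results = rerank.("how do I parse JSON in Ruby?", [
  "JSON.parse turns a JSON string into a Ruby hash.",
  "Rails is a web framework written in Ruby."
])

# Each result is {index:, score:}, sorted by score in descending order.
results.each { |r| puts "#{r[:index]}: #{r[:score].round(3)}" }
```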
data/lib/transformers/pipelines/reranking.rb ADDED
@@ -0,0 +1,33 @@
+ module Transformers
+   class RerankingPipeline < Pipeline
+     def _sanitize_parameters(**kwargs)
+       [{}, {}, kwargs]
+     end
+
+     def preprocess(inputs)
+       @tokenizer.(
+         [inputs[:query]] * inputs[:documents].length,
+         text_pair: inputs[:documents],
+         return_tensors: @framework
+       )
+     end
+
+     def _forward(model_inputs)
+       model_outputs = @model.(**model_inputs)
+       model_outputs
+     end
+
+     def call(query, documents)
+       super({query: query, documents: documents})
+     end
+
+     def postprocess(model_outputs)
+       model_outputs[0]
+         .sigmoid
+         .squeeze
+         .to_a
+         .map.with_index { |s, i| {index: i, score: s} }
+         .sort_by { |v| -v[:score] }
+     end
+   end
+ end
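
The postprocess step is just an elementwise sigmoid over the sequence-classification logits followed by a descending sort, so every document keeps its original index alongside its relevance score. A standalone sketch of that scoring logic with torch-rb, using made-up logits in place of real model output:

```ruby
require "torch"

# Hypothetical raw logits shaped like a sequence-classification head's output
# for three (query, document) pairs: [batch, 1].
logits = Torch.tensor([[2.1], [-0.4], [0.7]])

# Same steps as RerankingPipeline#postprocess: sigmoid -> flatten -> pair each
# score with its original document index -> sort by score, highest first.
ranked = logits
  .sigmoid
  .squeeze
  .to_a
  .map.with_index { |s, i| {index: i, score: s} }
  .sort_by { |v| -v[:score] }

ranked.each { |r| puts "doc #{r[:index]}: #{r[:score].round(4)}" }
# => doc 0 first (~0.89), then doc 2 (~0.67), then doc 1 (~0.40)
```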
@@ -1,3 +1,3 @@
  module Transformers
-   VERSION = "0.1.2"
+   VERSION = "0.1.3"
  end
data/lib/transformers.rb CHANGED
@@ -61,17 +61,32 @@ require_relative "transformers/models/bert/modeling_bert"
  require_relative "transformers/models/bert/tokenization_bert"
  require_relative "transformers/models/bert/tokenization_bert_fast"
 
+ # models deberta-v2
+ require_relative "transformers/models/deberta_v2/configuration_deberta_v2"
+ require_relative "transformers/models/deberta_v2/modeling_deberta_v2"
+ require_relative "transformers/models/deberta_v2/tokenization_deberta_v2_fast"
+
  # models distilbert
  require_relative "transformers/models/distilbert/configuration_distilbert"
  require_relative "transformers/models/distilbert/modeling_distilbert"
  require_relative "transformers/models/distilbert/tokenization_distilbert"
  require_relative "transformers/models/distilbert/tokenization_distilbert_fast"
 
+ # models mpnet
+ require_relative "transformers/models/mpnet/configuration_mpnet"
+ require_relative "transformers/models/mpnet/modeling_mpnet"
+ require_relative "transformers/models/mpnet/tokenization_mpnet_fast"
+
  # models vit
  require_relative "transformers/models/vit/configuration_vit"
  require_relative "transformers/models/vit/image_processing_vit"
  require_relative "transformers/models/vit/modeling_vit"
 
+ # models xlm-roberta
+ require_relative "transformers/models/xlm_roberta/configuration_xlm_roberta"
+ require_relative "transformers/models/xlm_roberta/modeling_xlm_roberta"
+ require_relative "transformers/models/xlm_roberta/tokenization_xlm_roberta_fast"
+
  # pipelines
  require_relative "transformers/pipelines/base"
  require_relative "transformers/pipelines/feature_extraction"
@@ -80,6 +95,7 @@ require_relative "transformers/pipelines/image_classification"
  require_relative "transformers/pipelines/image_feature_extraction"
  require_relative "transformers/pipelines/pt_utils"
  require_relative "transformers/pipelines/question_answering"
+ require_relative "transformers/pipelines/reranking"
  require_relative "transformers/pipelines/text_classification"
  require_relative "transformers/pipelines/token_classification"
  require_relative "transformers/pipelines/_init"
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: transformers-rb
  version: !ruby/object:Gem::Version
-   version: 0.1.2
+   version: 0.1.3
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-09-10 00:00:00.000000000 Z
+ date: 2024-09-17 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: 0.5.2
+         version: 0.5.3
    type: :runtime
    prerelease: false
    version_requirements: !ruby/object:Gem::Requirement
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: 0.5.2
+         version: 0.5.3
  - !ruby/object:Gem::Dependency
    name: torch-rb
    requirement: !ruby/object:Gem::Requirement
@@ -104,13 +104,22 @@ files:
  - lib/transformers/models/bert/modeling_bert.rb
  - lib/transformers/models/bert/tokenization_bert.rb
  - lib/transformers/models/bert/tokenization_bert_fast.rb
+ - lib/transformers/models/deberta_v2/configuration_deberta_v2.rb
+ - lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
+ - lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb
  - lib/transformers/models/distilbert/configuration_distilbert.rb
  - lib/transformers/models/distilbert/modeling_distilbert.rb
  - lib/transformers/models/distilbert/tokenization_distilbert.rb
  - lib/transformers/models/distilbert/tokenization_distilbert_fast.rb
+ - lib/transformers/models/mpnet/configuration_mpnet.rb
+ - lib/transformers/models/mpnet/modeling_mpnet.rb
+ - lib/transformers/models/mpnet/tokenization_mpnet_fast.rb
  - lib/transformers/models/vit/configuration_vit.rb
  - lib/transformers/models/vit/image_processing_vit.rb
  - lib/transformers/models/vit/modeling_vit.rb
+ - lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb
+ - lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
+ - lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb
  - lib/transformers/pipelines/_init.rb
  - lib/transformers/pipelines/base.rb
  - lib/transformers/pipelines/embedding.rb
@@ -119,6 +128,7 @@ files:
  - lib/transformers/pipelines/image_feature_extraction.rb
  - lib/transformers/pipelines/pt_utils.rb
  - lib/transformers/pipelines/question_answering.rb
+ - lib/transformers/pipelines/reranking.rb
  - lib/transformers/pipelines/text_classification.rb
  - lib/transformers/pipelines/token_classification.rb
  - lib/transformers/ruby_utils.rb