transformers-rb 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,68 @@
1
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
module Transformers
  module XlmRoberta
    # Fast XLM-RoBERTa tokenizer built on PreTrainedTokenizerFast.
    #
    # Special-token layout produced by #build_inputs_with_special_tokens:
    #   single sequence:   <s> A </s>
    #   pair of sequences: <s> A </s></s> B </s>
    class XLMRobertaTokenizerFast < PreTrainedTokenizerFast
      VOCAB_FILES_NAMES = {vocab_file: "sentencepiece.bpe.model", tokenizer_file: "tokenizer.json"}

      self.vocab_files_names = VOCAB_FILES_NAMES
      self.model_input_names = ["input_ids", "attention_mask"]
      # NOTE(review): no slow XLMRobertaTokenizer is referenced in this file, so
      # no slow tokenizer class is registered here.
      # self.slow_tokenizer_class = XLMRobertaTokenizer

      # @param vocab_file [String, nil] path to the SentencePiece model; forwarded
      #   to the superclass and kept in @vocab_file for #can_save_slow_tokenizer
      # @param tokenizer_file [String, nil] path to a serialized tokenizer.json
      # @param bos_token, eos_token, sep_token, cls_token, unk_token, pad_token [String]
      #   special tokens forwarded unchanged to the superclass
      # @param mask_token [String, Object] a String is wrapped in an AddedToken
      #   (see below); any other value is passed through as-is
      # @param kwargs [Hash] extra options forwarded to the superclass
      def initialize(
        vocab_file: nil,
        tokenizer_file: nil,
        bos_token: "<s>",
        eos_token: "</s>",
        sep_token: "</s>",
        cls_token: "<s>",
        unk_token: "<unk>",
        pad_token: "<pad>",
        mask_token: "<mask>",
        **kwargs
      )
        # The mask token behaves like a normal word, i.e. it absorbs the space
        # before it (lstrip) but not the one after it.
        mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token

        super(vocab_file, tokenizer_file: tokenizer_file, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, **kwargs)

        @vocab_file = vocab_file
      end

      # @return [Boolean] true only when a vocab file was given and exists on disk
      def can_save_slow_tokenizer
        @vocab_file ? File.exist?(@vocab_file) : false
      end

      # Wraps one or two id sequences with the CLS/SEP special-token ids.
      #
      # @param token_ids_0 [Array<Integer>] ids of the first sequence
      # @param token_ids_1 [Array<Integer>, nil] ids of the optional second sequence
      # @return [Array<Integer>] <s> A </s>  or  <s> A </s></s> B </s>
      def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
        cls = [@cls_token_id]
        sep = [@sep_token_id]
        return cls + token_ids_0 + sep if token_ids_1.nil?

        cls + token_ids_0 + sep + sep + token_ids_1 + sep
      end

      # Token type ids for the same layout as #build_inputs_with_special_tokens.
      # This tokenizer returns zeros for every position, special tokens included.
      #
      # @param token_ids_0 [Array<Integer>] ids of the first sequence
      # @param token_ids_1 [Array<Integer>, nil] ids of the optional second sequence
      # @return [Array<Integer>] an all-zero list as long as the wrapped input
      def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
        cls = [@cls_token_id]
        sep = [@sep_token_id]
        ids = token_ids_1.nil? ? cls + token_ids_0 + sep : cls + token_ids_0 + sep + sep + token_ids_1 + sep

        # Array-first multiplication: `Integer * Array` raises TypeError in Ruby.
        [0] * ids.length
      end
    end
  end
end
@@ -89,6 +89,16 @@ module Transformers
89
89
  },
90
90
  "type" => "text"
91
91
  },
92
+ "reranking" => {
93
+ "impl" => RerankingPipeline,
94
+ "pt" => [AutoModelForSequenceClassification],
95
+ "default" => {
96
+ "model" => {
97
+ "pt" => ["mixedbread-ai/mxbai-rerank-base-v1", "03241da"]
98
+ }
99
+ },
100
+ "type" => "text"
101
+ }
92
102
  }
93
103
 
94
104
  PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
@@ -0,0 +1,33 @@
1
module Transformers
  # Scores a list of documents against a query with a sequence-classification
  # model and returns the documents ordered by descending score.
  class RerankingPipeline < Pipeline
    # Routes every keyword argument to the postprocess stage. Preprocess and
    # forward take none.
    # NOTE(review): #postprocess accepts no keywords, so non-empty kwargs
    # presumably fail in the base pipeline — confirm.
    def _sanitize_parameters(**kwargs)
      [{}, {}, kwargs]
    end

    # Tokenizes every (query, document) pair as one batch.
    #
    # @param inputs [Hash] with :query (String) and :documents (Array<String>)
    # @return tokenizer output with tensors in the pipeline's @framework format
    def preprocess(inputs)
      documents = inputs[:documents]
      @tokenizer.(
        Array.new(documents.length, inputs[:query]),
        text_pair: documents,
        return_tensors: @framework
      )
    end

    # Runs the model on the tokenized batch and returns its raw outputs.
    def _forward(model_inputs)
      @model.(**model_inputs)
    end

    # @param query [String] the search query
    # @param documents [Array<String>] candidate documents to rank
    # @return [Array<Hash>] see #postprocess; [] when documents is empty
    def call(query, documents)
      # Nothing to score: don't send an empty batch through the tokenizer and model.
      return [] if documents.empty?

      super({query: query, documents: documents})
    end

    # @return [Array<Hash>] one {index:, score:} per document, highest score
    #   first. index is the document's position in the input array; score is
    #   the sigmoid of the first model output (presumably the relevance logit).
    def postprocess(model_outputs)
      # Logits are presumably (num_documents, 1). Squeeze only the last dim so a
      # single document still yields a 1-element list rather than a 0-d tensor.
      scores = model_outputs[0].sigmoid.squeeze(-1).to_a
      scores
        .map.with_index { |score, index| {index: index, score: score} }
        .sort_by { |entry| -entry[:score] }
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module Transformers
2
- VERSION = "0.1.2"
2
+ VERSION = "0.1.3"
3
3
  end
data/lib/transformers.rb CHANGED
@@ -61,17 +61,32 @@ require_relative "transformers/models/bert/modeling_bert"
61
61
  require_relative "transformers/models/bert/tokenization_bert"
62
62
  require_relative "transformers/models/bert/tokenization_bert_fast"
63
63
 
64
+ # models deberta-v2
65
+ require_relative "transformers/models/deberta_v2/configuration_deberta_v2"
66
+ require_relative "transformers/models/deberta_v2/modeling_deberta_v2"
67
+ require_relative "transformers/models/deberta_v2/tokenization_deberta_v2_fast"
68
+
64
69
  # models distilbert
65
70
  require_relative "transformers/models/distilbert/configuration_distilbert"
66
71
  require_relative "transformers/models/distilbert/modeling_distilbert"
67
72
  require_relative "transformers/models/distilbert/tokenization_distilbert"
68
73
  require_relative "transformers/models/distilbert/tokenization_distilbert_fast"
69
74
 
75
+ # models mpnet
76
+ require_relative "transformers/models/mpnet/configuration_mpnet"
77
+ require_relative "transformers/models/mpnet/modeling_mpnet"
78
+ require_relative "transformers/models/mpnet/tokenization_mpnet_fast"
79
+
70
80
  # models vit
71
81
  require_relative "transformers/models/vit/configuration_vit"
72
82
  require_relative "transformers/models/vit/image_processing_vit"
73
83
  require_relative "transformers/models/vit/modeling_vit"
74
84
 
85
+ # models xlm-roberta
86
+ require_relative "transformers/models/xlm_roberta/configuration_xlm_roberta"
87
+ require_relative "transformers/models/xlm_roberta/modeling_xlm_roberta"
88
+ require_relative "transformers/models/xlm_roberta/tokenization_xlm_roberta_fast"
89
+
75
90
  # pipelines
76
91
  require_relative "transformers/pipelines/base"
77
92
  require_relative "transformers/pipelines/feature_extraction"
@@ -80,6 +95,7 @@ require_relative "transformers/pipelines/image_classification"
80
95
  require_relative "transformers/pipelines/image_feature_extraction"
81
96
  require_relative "transformers/pipelines/pt_utils"
82
97
  require_relative "transformers/pipelines/question_answering"
98
+ require_relative "transformers/pipelines/reranking"
83
99
  require_relative "transformers/pipelines/text_classification"
84
100
  require_relative "transformers/pipelines/token_classification"
85
101
  require_relative "transformers/pipelines/_init"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: transformers-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.1.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-09-10 00:00:00.000000000 Z
11
+ date: 2024-09-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
44
44
  requirements:
45
45
  - - ">="
46
46
  - !ruby/object:Gem::Version
47
- version: 0.5.2
47
+ version: 0.5.3
48
48
  type: :runtime
49
49
  prerelease: false
50
50
  version_requirements: !ruby/object:Gem::Requirement
51
51
  requirements:
52
52
  - - ">="
53
53
  - !ruby/object:Gem::Version
54
- version: 0.5.2
54
+ version: 0.5.3
55
55
  - !ruby/object:Gem::Dependency
56
56
  name: torch-rb
57
57
  requirement: !ruby/object:Gem::Requirement
@@ -104,13 +104,22 @@ files:
104
104
  - lib/transformers/models/bert/modeling_bert.rb
105
105
  - lib/transformers/models/bert/tokenization_bert.rb
106
106
  - lib/transformers/models/bert/tokenization_bert_fast.rb
107
+ - lib/transformers/models/deberta_v2/configuration_deberta_v2.rb
108
+ - lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
109
+ - lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb
107
110
  - lib/transformers/models/distilbert/configuration_distilbert.rb
108
111
  - lib/transformers/models/distilbert/modeling_distilbert.rb
109
112
  - lib/transformers/models/distilbert/tokenization_distilbert.rb
110
113
  - lib/transformers/models/distilbert/tokenization_distilbert_fast.rb
114
+ - lib/transformers/models/mpnet/configuration_mpnet.rb
115
+ - lib/transformers/models/mpnet/modeling_mpnet.rb
116
+ - lib/transformers/models/mpnet/tokenization_mpnet_fast.rb
111
117
  - lib/transformers/models/vit/configuration_vit.rb
112
118
  - lib/transformers/models/vit/image_processing_vit.rb
113
119
  - lib/transformers/models/vit/modeling_vit.rb
120
+ - lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb
121
+ - lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
122
+ - lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb
114
123
  - lib/transformers/pipelines/_init.rb
115
124
  - lib/transformers/pipelines/base.rb
116
125
  - lib/transformers/pipelines/embedding.rb
@@ -119,6 +128,7 @@ files:
119
128
  - lib/transformers/pipelines/image_feature_extraction.rb
120
129
  - lib/transformers/pipelines/pt_utils.rb
121
130
  - lib/transformers/pipelines/question_answering.rb
131
+ - lib/transformers/pipelines/reranking.rb
122
132
  - lib/transformers/pipelines/text_classification.rb
123
133
  - lib/transformers/pipelines/token_classification.rb
124
134
  - lib/transformers/ruby_utils.rb