informers 1.0.2 → 1.0.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 4ea317272c5054b01616643e7e0f0b2b2fe0c4a87fe8399350a6b8d0a279c5a1
-   data.tar.gz: 530f8aaab9a5ca71811a82adca0272e2ca84525bcf1f60f2209c394cbd0f9c2a
+   metadata.gz: f5340da0bce9d55a0339fac6b8806f09119df3e89567ecb37a77e1a5921b8fa2
+   data.tar.gz: 66a9d275cb2999ad14ba1cfd900bdcbf9fdc3d26ce29387acdd74452bf2050ef
  SHA512:
-   metadata.gz: 76059b486e6f6c0b0054450f76813dd4bf12845da6f46e8089585cd1a69be7db86a0acf446cc5a18e48108393403324626f6656d09bdb69083f2651abc0d2448
-   data.tar.gz: f466f5382edd76a7092dc6ada349a3e58fe7eedcd481726ca765f8ddfb4543b7269dab96c00a93d10b0fd67f800afd70a619cfb15d78dde494b29cc13d21ef1a
+   metadata.gz: a4a0c3da3d8a3555a6f2debca8f2939b6536ac76386cdd6c7264890b2d00842d537ecfca352021fa349ff9c4636ba49c189f652a66676746d9ec2a8d97eecc2a
+   data.tar.gz: a06aa115b5966fd1b8da7a80d8481d3e61778f31c3bb0da143f329e81ae3f73d4a1d1b2ee01672f4e90742a35d68a23dd5c871c3b68ffad0c16d8e5de480a60f
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 1.0.3 (2024-08-29)
+
+ - Added `model_output` option
+ - Improved `model_file_name` option
+
  ## 1.0.2 (2024-08-28)

  - Added `embedding` pipeline
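
Both options are exercised in the README changes below: `model_file_name` points the pipeline at an ONNX file outside the default `onnx/` directory, and `model_output` selects a named model output. A condensed sketch, using the two models documented in this release:

```ruby
require "informers"

# model_file_name: "../model" resolves to model.onnx at the repo root
# instead of the default onnx/model.onnx
model = Informers.pipeline("embedding", "jinaai/jina-embeddings-v2-base-en", model_file_name: "../model")
embeddings = model.(["How is the weather today?"])

# model_output: read the pre-pooled "sentence_embedding" output by name
model = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
embeddings = model.(["The dog is barking"], model_output: "sentence_embedding", pooling: "none")
```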
data/README.md CHANGED
@@ -30,10 +30,15 @@ Embedding
  - [intfloat/e5-base-v2](#intfloate5-base-v2)
  - [nomic-ai/nomic-embed-text-v1](#nomic-ainomic-embed-text-v1)
  - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
+ - [jinaai/jina-embeddings-v2-base-en](#jinaaijina-embeddings-v2-base-en)
+ - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+ - [Xenova/all-mpnet-base-v2](#xenovaall-mpnet-base-v2)

- Reranking (experimental)
+ Reranking

  - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+ - [jinaai/jina-reranker-v1-turbo-en](#jinaaijina-reranker-v1-turbo-en)
+ - [BAAI/bge-reranker-base](#baaibge-reranker-base)

  ### sentence-transformers/all-MiniLM-L6-v2

@@ -72,18 +77,16 @@ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

  ```ruby
- def transform_query(query)
-   "Represent this sentence for searching relevant passages: #{query}"
- end
+ query_prefix = "Represent this sentence for searching relevant passages: "

- docs = [
-   transform_query("puppy"),
+ input = [
    "The dog is barking",
-   "The cat is purring"
+   "The cat is purring",
+   query_prefix + "puppy"
  ]

  model = Informers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
- embeddings = model.(docs)
+ embeddings = model.(input)
  ```

  ### Supabase/gte-small

@@ -102,9 +105,12 @@ embeddings = model.(sentences)
  [Docs](https://huggingface.co/intfloat/e5-base-v2)

  ```ruby
+ doc_prefix = "passage: "
+ query_prefix = "query: "
+
  input = [
-   "passage: Ruby is a programming language created by Matz",
-   "query: Ruby creator"
+   doc_prefix + "Ruby is a programming language created by Matz",
+   query_prefix + "Ruby creator"
  ]

  model = Informers.pipeline("embedding", "intfloat/e5-base-v2")
@@ -116,9 +122,13 @@ embeddings = model.(input)
  [Docs](https://huggingface.co/nomic-ai/nomic-embed-text-v1)

  ```ruby
+ doc_prefix = "search_document: "
+ query_prefix = "search_query: "
+
  input = [
-   "search_document: The dog is barking",
-   "search_query: puppy"
+   doc_prefix + "The dog is barking",
+   doc_prefix + "The cat is purring",
+   query_prefix + "puppy"
  ]

  model = Informers.pipeline("embedding", "nomic-ai/nomic-embed-text-v1")
@@ -130,20 +140,57 @@ embeddings = model.(input)
  [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)

  ```ruby
- def transform_query(query)
-   "Represent this sentence for searching relevant passages: #{query}"
- end
+ query_prefix = "Represent this sentence for searching relevant passages: "

  input = [
-   transform_query("puppy"),
    "The dog is barking",
-   "The cat is purring"
+   "The cat is purring",
+   query_prefix + "puppy"
  ]

  model = Informers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
  embeddings = model.(input)
  ```

+ ### jinaai/jina-embeddings-v2-base-en
+
+ [Docs](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)
+
+ ```ruby
+ sentences = ["How is the weather today?", "What is the current weather like today?"]
+
+ model = Informers.pipeline("embedding", "jinaai/jina-embeddings-v2-base-en", model_file_name: "../model")
+ embeddings = model.(sentences)
+ ```
+
+ ### Snowflake/snowflake-arctic-embed-m-v1.5
+
+ [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+   "The dog is barking",
+   "The cat is purring",
+   query_prefix + "puppy"
+ ]
+
+ model = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
+ embeddings = model.(input, model_output: "sentence_embedding", pooling: "none")
+ ```
+
+ ### Xenova/all-mpnet-base-v2
+
+ [Docs](https://huggingface.co/Xenova/all-mpnet-base-v2)
+
+ ```ruby
+ sentences = ["This is an example sentence", "Each sentence is converted"]
+
+ model = Informers.pipeline("embedding", "Xenova/all-mpnet-base-v2")
+ embeddings = model.(sentences)
+ ```
+
  ### mixedbread-ai/mxbai-rerank-base-v1

  [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
@@ -156,6 +203,30 @@ model = Informers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
  result = model.(query, docs)
  ```

+ ### jinaai/jina-reranker-v1-turbo-en
+
+ [Docs](https://huggingface.co/jinaai/jina-reranker-v1-turbo-en)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Informers.pipeline("reranking", "jinaai/jina-reranker-v1-turbo-en")
+ result = model.(query, docs)
+ ```
+
+ ### BAAI/bge-reranker-base
+
+ [Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Informers.pipeline("reranking", "BAAI/bge-reranker-base")
+ result = model.(query, docs)
+ ```
+
  ### Other

  You can use the feature extraction pipeline directly.
@@ -165,7 +236,7 @@ model = Informers.pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2", quan
  embeddings = model.(sentences, pooling: "mean", normalize: true)
  ```

- The model files must include `onnx/model.onnx` or `onnx/model_quantized.onnx` ([example](https://huggingface.co/Xenova/all-MiniLM-L6-v2/tree/main/onnx)).
+ The model must include a `.onnx` file ([example](https://huggingface.co/Xenova/all-MiniLM-L6-v2/tree/main/onnx)). If the file is not at `onnx/model.onnx` or `onnx/model_quantized.onnx`, use the `model_file_name` option to specify the location.

  ## Pipelines

@@ -176,7 +247,7 @@ embed = Informers.pipeline("embedding")
  embed.("We are very happy to show you the 🤗 Transformers library.")
  ```

- Reranking (experimental)
+ Reranking

  ```ruby
  rerank = Informers.pipeline("reranking")
@@ -6,19 +6,14 @@ module Informers
    end

    def embed(texts)
-     is_batched = texts.is_a?(Array)
-     texts = [texts] unless is_batched
-
      case @model_id
      when "sentence-transformers/all-MiniLM-L6-v2", "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1", "Supabase/gte-small"
-       output = @model.(texts)
+       @model.(texts)
      when "mixedbread-ai/mxbai-embed-large-v1"
-       output = @model.(texts, pooling: "cls", normalize: false)
+       @model.(texts, pooling: "cls", normalize: false)
      else
        raise Error, "Use the embedding pipeline for this model: #{@model_id}"
      end
-
-     is_batched ? output : output[0]
    end
  end
end
@@ -135,7 +135,15 @@ module Informers
    end

    def self.construct_session(pretrained_model_name_or_path, file_name, **options)
-     model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+     prefix = "onnx/"
+     if file_name.start_with?("../")
+       prefix = ""
+       file_name = file_name[3..]
+     elsif file_name.start_with?("/")
+       prefix = ""
+       file_name = file_name[1..]
+     end
+     model_file_name = "#{prefix}#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
      path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)

      OnnxRuntime::InferenceSession.new(path)
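
The rewritten `construct_session` resolves `model_file_name` relative to the default `onnx/` directory: a plain name stays under `onnx/`, while a `../` or leading `/` prefix escapes it, which is what lets the README's jina example load `model.onnx` from the repo root. A minimal standalone sketch of the resolution, with a hypothetical `quantized:` flag standing in for the options hash:

```ruby
# Sketch of the path resolution above, not the library's public API
def resolve_model_file(file_name, quantized: false)
  prefix = "onnx/"
  if file_name.start_with?("../")
    prefix = ""
    file_name = file_name[3..]  # drop "../" and resolve from the repo root
  elsif file_name.start_with?("/")
    prefix = ""
    file_name = file_name[1..]  # drop the leading "/"
  end
  "#{prefix}#{file_name}#{quantized ? "_quantized" : ""}.onnx"
end

resolve_model_file("model")                  # => "onnx/model.onnx"
resolve_model_file("model", quantized: true) # => "onnx/model_quantized.onnx"
resolve_model_file("../model")               # => "model.onnx"
```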
@@ -229,16 +237,37 @@ module Informers
    end
  end

+ class MPNetPreTrainedModel < PreTrainedModel
+ end
+
+ class MPNetModel < MPNetPreTrainedModel
+ end
+
+ class XLMRobertaPreTrainedModel < PreTrainedModel
+ end
+
+ class XLMRobertaModel < XLMRobertaPreTrainedModel
+ end
+
+ class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
+   def call(model_inputs)
+     SequenceClassifierOutput.new(*super(model_inputs))
+   end
+ end
+
  MODEL_MAPPING_NAMES_ENCODER_ONLY = {
    "bert" => ["BertModel", BertModel],
    "nomic_bert" => ["NomicBertModel", NomicBertModel],
    "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
-   "distilbert" => ["DistilBertModel", DistilBertModel]
+   "mpnet" => ["MPNetModel", MPNetModel],
+   "distilbert" => ["DistilBertModel", DistilBertModel],
+   "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel]
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
    "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
-   "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+   "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
+   "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
  }

  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
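
For context, the new entries register the MPNet and XLM-RoBERTa architectures behind the models added to the README. A hypothetical lookup (not library code) showing how a mapping entry is consumed, assuming the `"xlm-roberta"` key matches the `model_type` string from a checkpoint's `config.json`:

```ruby
name, klass = MODEL_MAPPING_NAMES_ENCODER_ONLY["xlm-roberta"]
name  # => "XLMRobertaModel"
klass # => Informers::XLMRobertaModel
```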
@@ -249,7 +249,8 @@ module Informers
      pooling: "none",
      normalize: false,
      quantize: false,
-     precision: "binary"
+     precision: "binary",
+     model_output: nil
    )
      # Run tokenization
      model_inputs = @tokenizer.(texts,
@@ -258,8 +259,10 @@
      )
      model_options = {}

-     # optimization for sentence-transformers/all-MiniLM-L6-v2
-     if @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
+     if !model_output.nil?
+       model_options[:output_names] = Array(model_output)
+     elsif @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
+       # optimization for sentence-transformers/all-MiniLM-L6-v2
        model_options[:output_names] = ["sentence_embedding"]
        pooling = "none"
        normalize = false
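
This is the plumbing behind the README's Snowflake example: an explicit `model_output` overrides the session's output names (wrapped with `Array()` so a single string works), bypassing the all-MiniLM fast path:

```ruby
query_prefix = "Represent this sentence for searching relevant passages: "
input = ["The dog is barking", query_prefix + "puppy"]

model = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
# request the ONNX output named "sentence_embedding" and skip Ruby-side pooling
embeddings = model.(input, model_output: "sentence_embedding", pooling: "none")
```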
@@ -271,7 +274,9 @@ module Informers
      # TODO improve
      result =
        if outputs.is_a?(Array)
-         raise Error, "unexpected outputs" if outputs.size != 1
+         # TODO show returned instead of all
+         output_names = @model.instance_variable_get(:@session).outputs.map { |v| v[:name] }
+         raise Error, "unexpected outputs: #{output_names}" if outputs.size != 1
          outputs[0]
        else
          outputs.logits
@@ -285,6 +290,7 @@ module Informers
        when "cls"
          result = result.map(&:first)
        else
+         # TODO raise ArgumentError in 2.0
          raise Error, "Pooling method '#{pooling}' not supported."
        end

@@ -304,9 +310,10 @@ module Informers
      def call(
        texts,
        pooling: "mean",
-       normalize: true
+       normalize: true,
+       model_output: nil
      )
-       super(texts, pooling:, normalize:)
+       super(texts, pooling:, normalize:, model_output:)
      end
    end

@@ -91,11 +91,23 @@ module Informers
    class DistilBertTokenizer < PreTrainedTokenizer
    end

+   class RobertaTokenizer < PreTrainedTokenizer
+   end
+
+   class XLMRobertaTokenizer < PreTrainedTokenizer
+   end
+
+   class MPNetTokenizer < PreTrainedTokenizer
+   end
+
    class AutoTokenizer
      TOKENIZER_CLASS_MAPPING = {
        "BertTokenizer" => BertTokenizer,
        "DebertaV2Tokenizer" => DebertaV2Tokenizer,
-       "DistilBertTokenizer" => DistilBertTokenizer
+       "DistilBertTokenizer" => DistilBertTokenizer,
+       "RobertaTokenizer" => RobertaTokenizer,
+       "XLMRobertaTokenizer" => XLMRobertaTokenizer,
+       "MPNetTokenizer" => MPNetTokenizer
      }

      def self.from_pretrained(
@@ -1,3 +1,3 @@
  module Informers
-   VERSION = "1.0.2"
+   VERSION = "1.0.3"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: informers
  version: !ruby/object:Gem::Version
-   version: 1.0.2
+   version: 1.0.3
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-08-28 00:00:00.000000000 Z
+ date: 2024-08-29 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: onnxruntime