informers 1.0.2 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 4ea317272c5054b01616643e7e0f0b2b2fe0c4a87fe8399350a6b8d0a279c5a1
4
- data.tar.gz: 530f8aaab9a5ca71811a82adca0272e2ca84525bcf1f60f2209c394cbd0f9c2a
3
+ metadata.gz: f5340da0bce9d55a0339fac6b8806f09119df3e89567ecb37a77e1a5921b8fa2
4
+ data.tar.gz: 66a9d275cb2999ad14ba1cfd900bdcbf9fdc3d26ce29387acdd74452bf2050ef
5
5
  SHA512:
6
- metadata.gz: 76059b486e6f6c0b0054450f76813dd4bf12845da6f46e8089585cd1a69be7db86a0acf446cc5a18e48108393403324626f6656d09bdb69083f2651abc0d2448
7
- data.tar.gz: f466f5382edd76a7092dc6ada349a3e58fe7eedcd481726ca765f8ddfb4543b7269dab96c00a93d10b0fd67f800afd70a619cfb15d78dde494b29cc13d21ef1a
6
+ metadata.gz: a4a0c3da3d8a3555a6f2debca8f2939b6536ac76386cdd6c7264890b2d00842d537ecfca352021fa349ff9c4636ba49c189f652a66676746d9ec2a8d97eecc2a
7
+ data.tar.gz: a06aa115b5966fd1b8da7a80d8481d3e61778f31c3bb0da143f329e81ae3f73d4a1d1b2ee01672f4e90742a35d68a23dd5c871c3b68ffad0c16d8e5de480a60f
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 1.0.3 (2024-08-29)
2
+
3
+ - Added `model_output` option
4
+ - Improved `model_file_name` option
5
+
1
6
  ## 1.0.2 (2024-08-28)
2
7
 
3
8
  - Added `embedding` pipeline
data/README.md CHANGED
@@ -30,10 +30,15 @@ Embedding
30
30
  - [intfloat/e5-base-v2](#intfloate5-base-v2)
31
31
  - [nomic-ai/nomic-embed-text-v1](#nomic-ainomic-embed-text-v1)
32
32
  - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
33
+ - [jinaai/jina-embeddings-v2-base-en](#jinaaijina-embeddings-v2-base-en)
34
+ - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
35
+ - [Xenova/all-mpnet-base-v2](#xenovaall-mpnet-base-v2)
33
36
 
34
- Reranking (experimental)
37
+ Reranking
35
38
 
36
39
  - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
40
+ - [jinaai/jina-reranker-v1-turbo-en](#jinaaijina-reranker-v1-turbo-en)
41
+ - [BAAI/bge-reranker-base](#baaibge-reranker-base)
37
42
 
38
43
  ### sentence-transformers/all-MiniLM-L6-v2
39
44
 
@@ -72,18 +77,16 @@ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
72
77
  [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
73
78
 
74
79
  ```ruby
75
- def transform_query(query)
76
- "Represent this sentence for searching relevant passages: #{query}"
77
- end
80
+ query_prefix = "Represent this sentence for searching relevant passages: "
78
81
 
79
- docs = [
80
- transform_query("puppy"),
82
+ input = [
81
83
  "The dog is barking",
82
- "The cat is purring"
84
+ "The cat is purring",
85
+ query_prefix + "puppy"
83
86
  ]
84
87
 
85
88
  model = Informers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
86
- embeddings = model.(docs)
89
+ embeddings = model.(input)
87
90
  ```
88
91
 
89
92
  ### Supabase/gte-small
@@ -102,9 +105,12 @@ embeddings = model.(sentences)
102
105
  [Docs](https://huggingface.co/intfloat/e5-base-v2)
103
106
 
104
107
  ```ruby
108
+ doc_prefix = "passage: "
109
+ query_prefix = "query: "
110
+
105
111
  input = [
106
- "passage: Ruby is a programming language created by Matz",
107
- "query: Ruby creator"
112
+ doc_prefix + "Ruby is a programming language created by Matz",
113
+ query_prefix + "Ruby creator"
108
114
  ]
109
115
 
110
116
  model = Informers.pipeline("embedding", "intfloat/e5-base-v2")
@@ -116,9 +122,13 @@ embeddings = model.(input)
116
122
  [Docs](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
117
123
 
118
124
  ```ruby
125
+ doc_prefix = "search_document: "
126
+ query_prefix = "search_query: "
127
+
119
128
  input = [
120
- "search_document: The dog is barking",
121
- "search_query: puppy"
129
+ doc_prefix + "The dog is barking",
130
+ doc_prefix + "The cat is purring",
131
+ query_prefix + "puppy"
122
132
  ]
123
133
 
124
134
  model = Informers.pipeline("embedding", "nomic-ai/nomic-embed-text-v1")
@@ -130,20 +140,57 @@ embeddings = model.(input)
130
140
  [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)
131
141
 
132
142
  ```ruby
133
- def transform_query(query)
134
- "Represent this sentence for searching relevant passages: #{query}"
135
- end
143
+ query_prefix = "Represent this sentence for searching relevant passages: "
136
144
 
137
145
  input = [
138
- transform_query("puppy"),
139
146
  "The dog is barking",
140
- "The cat is purring"
147
+ "The cat is purring",
148
+ query_prefix + "puppy"
141
149
  ]
142
150
 
143
151
  model = Informers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
144
152
  embeddings = model.(input)
145
153
  ```
146
154
 
155
+ ### jinaai/jina-embeddings-v2-base-en
156
+
157
+ [Docs](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)
158
+
159
+ ```ruby
160
+ sentences = ["How is the weather today?", "What is the current weather like today?"]
161
+
162
+ model = Informers.pipeline("embedding", "jinaai/jina-embeddings-v2-base-en", model_file_name: "../model")
163
+ embeddings = model.(sentences)
164
+ ```
165
+
166
+ ### Snowflake/snowflake-arctic-embed-m-v1.5
167
+
168
+ [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
169
+
170
+ ```ruby
171
+ query_prefix = "Represent this sentence for searching relevant passages: "
172
+
173
+ input = [
174
+ "The dog is barking",
175
+ "The cat is purring",
176
+ query_prefix + "puppy"
177
+ ]
178
+
179
+ model = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
180
+ embeddings = model.(input, model_output: "sentence_embedding", pooling: "none")
181
+ ```
182
+
183
+ ### Xenova/all-mpnet-base-v2
184
+
185
+ [Docs](https://huggingface.co/Xenova/all-mpnet-base-v2)
186
+
187
+ ```ruby
188
+ sentences = ["This is an example sentence", "Each sentence is converted"]
189
+
190
+ model = Informers.pipeline("embedding", "Xenova/all-mpnet-base-v2")
191
+ embeddings = model.(sentences)
192
+ ```
193
+
147
194
  ### mixedbread-ai/mxbai-rerank-base-v1
148
195
 
149
196
  [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
@@ -156,6 +203,30 @@ model = Informers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
156
203
  result = model.(query, docs)
157
204
  ```
158
205
 
206
+ ### jinaai/jina-reranker-v1-turbo-en
207
+
208
+ [Docs](https://huggingface.co/jinaai/jina-reranker-v1-turbo-en)
209
+
210
+ ```ruby
211
+ query = "How many people live in London?"
212
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
213
+
214
+ model = Informers.pipeline("reranking", "jinaai/jina-reranker-v1-turbo-en")
215
+ result = model.(query, docs)
216
+ ```
217
+
218
+ ### BAAI/bge-reranker-base
219
+
220
+ [Docs](https://huggingface.co/BAAI/bge-reranker-base)
221
+
222
+ ```ruby
223
+ query = "How many people live in London?"
224
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
225
+
226
+ model = Informers.pipeline("reranking", "BAAI/bge-reranker-base")
227
+ result = model.(query, docs)
228
+ ```
229
+
159
230
  ### Other
160
231
 
161
232
  You can use the feature extraction pipeline directly.
@@ -165,7 +236,7 @@ model = Informers.pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2", quan
165
236
  embeddings = model.(sentences, pooling: "mean", normalize: true)
166
237
  ```
167
238
 
168
- The model files must include `onnx/model.onnx` or `onnx/model_quantized.onnx` ([example](https://huggingface.co/Xenova/all-MiniLM-L6-v2/tree/main/onnx)).
239
+ The model must include a `.onnx` file ([example](https://huggingface.co/Xenova/all-MiniLM-L6-v2/tree/main/onnx)). If the file is not at `onnx/model.onnx` or `onnx/model_quantized.onnx`, use the `model_file_name` option to specify the location.
169
240
 
170
241
  ## Pipelines
171
242
 
@@ -176,7 +247,7 @@ embed = Informers.pipeline("embedding")
176
247
  embed.("We are very happy to show you the 🤗 Transformers library.")
177
248
  ```
178
249
 
179
- Reranking (experimental)
250
+ Reranking
180
251
 
181
252
  ```ruby
182
253
  rerank = Informers.pipeline("reranking")
@@ -6,19 +6,14 @@ module Informers
6
6
  end
7
7
 
8
8
  def embed(texts)
9
- is_batched = texts.is_a?(Array)
10
- texts = [texts] unless is_batched
11
-
12
9
  case @model_id
13
10
  when "sentence-transformers/all-MiniLM-L6-v2", "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1", "Supabase/gte-small"
14
- output = @model.(texts)
11
+ @model.(texts)
15
12
  when "mixedbread-ai/mxbai-embed-large-v1"
16
- output = @model.(texts, pooling: "cls", normalize: false)
13
+ @model.(texts, pooling: "cls", normalize: false)
17
14
  else
18
15
  raise Error, "Use the embedding pipeline for this model: #{@model_id}"
19
16
  end
20
-
21
- is_batched ? output : output[0]
22
17
  end
23
18
  end
24
19
  end
@@ -135,7 +135,15 @@ module Informers
135
135
  end
136
136
 
137
137
  def self.construct_session(pretrained_model_name_or_path, file_name, **options)
138
- model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
138
+ prefix = "onnx/"
139
+ if file_name.start_with?("../")
140
+ prefix = ""
141
+ file_name = file_name[3..]
142
+ elsif file_name.start_with?("/")
143
+ prefix = ""
144
+ file_name = file_name[1..]
145
+ end
146
+ model_file_name = "#{prefix}#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
139
147
  path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
140
148
 
141
149
  OnnxRuntime::InferenceSession.new(path)
@@ -229,16 +237,37 @@ module Informers
229
237
  end
230
238
  end
231
239
 
240
+ class MPNetPreTrainedModel < PreTrainedModel
241
+ end
242
+
243
+ class MPNetModel < MPNetPreTrainedModel
244
+ end
245
+
246
+ class XLMRobertaPreTrainedModel < PreTrainedModel
247
+ end
248
+
249
+ class XLMRobertaModel < XLMRobertaPreTrainedModel
250
+ end
251
+
252
+ class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
253
+ def call(model_inputs)
254
+ SequenceClassifierOutput.new(*super(model_inputs))
255
+ end
256
+ end
257
+
232
258
  MODEL_MAPPING_NAMES_ENCODER_ONLY = {
233
259
  "bert" => ["BertModel", BertModel],
234
260
  "nomic_bert" => ["NomicBertModel", NomicBertModel],
235
261
  "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
236
- "distilbert" => ["DistilBertModel", DistilBertModel]
262
+ "mpnet" => ["MPNetModel", MPNetModel],
263
+ "distilbert" => ["DistilBertModel", DistilBertModel],
264
+ "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel]
237
265
  }
238
266
 
239
267
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
240
268
  "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
241
- "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
269
+ "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
270
+ "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
242
271
  }
243
272
 
244
273
  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
@@ -249,7 +249,8 @@ module Informers
249
249
  pooling: "none",
250
250
  normalize: false,
251
251
  quantize: false,
252
- precision: "binary"
252
+ precision: "binary",
253
+ model_output: nil
253
254
  )
254
255
  # Run tokenization
255
256
  model_inputs = @tokenizer.(texts,
@@ -258,8 +259,10 @@ module Informers
258
259
  )
259
260
  model_options = {}
260
261
 
261
- # optimization for sentence-transformers/all-MiniLM-L6-v2
262
- if @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
262
+ if !model_output.nil?
263
+ model_options[:output_names] = Array(model_output)
264
+ elsif @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize
265
+ # optimization for sentence-transformers/all-MiniLM-L6-v2
263
266
  model_options[:output_names] = ["sentence_embedding"]
264
267
  pooling = "none"
265
268
  normalize = false
@@ -271,7 +274,9 @@ module Informers
271
274
  # TODO improve
272
275
  result =
273
276
  if outputs.is_a?(Array)
274
- raise Error, "unexpected outputs" if outputs.size != 1
277
+ # TODO show returned instead of all
278
+ output_names = @model.instance_variable_get(:@session).outputs.map { |v| v[:name] }
279
+ raise Error, "unexpected outputs: #{output_names}" if outputs.size != 1
275
280
  outputs[0]
276
281
  else
277
282
  outputs.logits
@@ -285,6 +290,7 @@ module Informers
285
290
  when "cls"
286
291
  result = result.map(&:first)
287
292
  else
293
+ # TODO raise ArgumentError in 2.0
288
294
  raise Error, "Pooling method '#{pooling}' not supported."
289
295
  end
290
296
 
@@ -304,9 +310,10 @@ module Informers
304
310
  def call(
305
311
  texts,
306
312
  pooling: "mean",
307
- normalize: true
313
+ normalize: true,
314
+ model_output: nil
308
315
  )
309
- super(texts, pooling:, normalize:)
316
+ super(texts, pooling:, normalize:, model_output:)
310
317
  end
311
318
  end
312
319
 
@@ -91,11 +91,23 @@ module Informers
91
91
  class DistilBertTokenizer < PreTrainedTokenizer
92
92
  end
93
93
 
94
+ class RobertaTokenizer < PreTrainedTokenizer
95
+ end
96
+
97
+ class XLMRobertaTokenizer < PreTrainedTokenizer
98
+ end
99
+
100
+ class MPNetTokenizer < PreTrainedTokenizer
101
+ end
102
+
94
103
  class AutoTokenizer
95
104
  TOKENIZER_CLASS_MAPPING = {
96
105
  "BertTokenizer" => BertTokenizer,
97
106
  "DebertaV2Tokenizer" => DebertaV2Tokenizer,
98
- "DistilBertTokenizer" => DistilBertTokenizer
107
+ "DistilBertTokenizer" => DistilBertTokenizer,
108
+ "RobertaTokenizer" => RobertaTokenizer,
109
+ "XLMRobertaTokenizer" => XLMRobertaTokenizer,
110
+ "MPNetTokenizer" => MPNetTokenizer
99
111
  }
100
112
 
101
113
  def self.from_pretrained(
@@ -1,3 +1,3 @@
1
1
  module Informers
2
- VERSION = "1.0.2"
2
+ VERSION = "1.0.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: informers
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.0.2
4
+ version: 1.0.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2024-08-28 00:00:00.000000000 Z
11
+ date: 2024-08-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: onnxruntime