fastembed 1.0.0 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,217 @@
1
+ # frozen_string_literal: true
2
+
3
module Fastembed
  # Cross-encoder model for reranking query-document pairs
  #
  # Unlike embedding models that encode texts independently, cross-encoders
  # process query-document pairs together to produce relevance scores.
  # This is usually more accurate, but it is slower: every document needs its
  # own forward pass together with the query at query time. Document
  # embeddings, by contrast, can be computed once ahead of time and compared
  # cheaply against a query embedding.
  #
  # @example Basic reranking
  #   reranker = Fastembed::TextCrossEncoder.new
  #   scores = reranker.rerank(
  #     query: "What is machine learning?",
  #     documents: ["ML is a subset of AI...", "The weather is nice today"]
  #   )
  #   # => [0.95, 0.02]
  #
  # @example Get ranked results
  #   results = reranker.rerank_with_scores(
  #     query: "What is Ruby?",
  #     documents: documents
  #   )
  #   # => [{document: "Ruby is a programming...", score: 0.89, index: 2}, ...]
  #
  class TextCrossEncoder
    include BaseModel

    # Initialize a cross-encoder model for reranking
    #
    # When +local_model_dir+ is given, the model is loaded from that directory.
    # In that case +cache_dir+ and +show_progress+ are not used.
    #
    # @param model_name [String] Name of the model to use
    # @param cache_dir [String, nil] Custom cache directory for models
    # @param threads [Integer, nil] Number of threads for ONNX Runtime
    # @param providers [Array<String>, nil] ONNX execution providers
    # @param show_progress [Boolean] Whether to show download progress
    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
    # @param local_model_dir [String, nil] Load model from local directory instead of downloading
    # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
    # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
    def initialize(
      model_name: DEFAULT_RERANKER_MODEL,
      cache_dir: nil,
      threads: nil,
      providers: nil,
      show_progress: true,
      quantization: nil,
      local_model_dir: nil,
      model_file: nil,
      tokenizer_file: nil
    )
      if local_model_dir
        initialize_from_local(
          local_model_dir: local_model_dir,
          model_name: model_name,
          threads: threads,
          providers: providers,
          quantization: quantization,
          model_file: model_file,
          tokenizer_file: tokenizer_file
        )
      else
        initialize_model(
          model_name: model_name,
          cache_dir: cache_dir,
          threads: threads,
          providers: providers,
          show_progress: show_progress,
          quantization: quantization
        )
      end

      # An explicit model_file takes precedence over quantized_model_file (from BaseModel)
      setup_model_and_tokenizer(model_file_override: model_file || quantized_model_file)
    end

    # Score query-document pairs and return relevance scores
    #
    # Scores are returned in the same order as +documents+.
    #
    # @param query [String] The query text
    # @param documents [Array<String>] Documents to score against the query
    # @param batch_size [Integer] Number of pairs to process at once
    # @return [Array<Float>] Relevance scores for each document (higher = more relevant)
    # @raise [ArgumentError] If query or documents is nil, or documents contains nil
    def rerank(query:, documents:, batch_size: 64)
      documents = Validators.validate_rerank_input!(query: query, documents: documents)
      score_documents(query, documents, batch_size)
    end

    # Rerank documents and return sorted results with scores
    #
    # Results are ordered by descending score. Documents with equal scores keep
    # their original input order, so the output is deterministic.
    #
    # @param query [String] The query text
    # @param documents [Array<String>] Documents to rerank
    # @param top_k [Integer, nil] Return only top K results (nil = all)
    # @param batch_size [Integer] Number of pairs to process at once
    # @return [Array<Hash>] Sorted results with :document, :score, :index keys
    # @raise [ArgumentError] If query or documents is nil, or documents contains nil
    def rerank_with_scores(query:, documents:, top_k: nil, batch_size: 64)
      documents = Validators.validate_rerank_input!(query: query, documents: documents)
      scores = score_documents(query, documents, batch_size)

      results = documents.each_with_index.map do |doc, idx|
        { document: doc, score: scores[idx], index: idx }
      end

      # sort_by is not stable; the index tie-break keeps equal scores in input order.
      results.sort_by! { |r| [-r[:score], r[:index]] }
      top_k ? results.first(top_k) : results
    end

    # Rerank documents asynchronously
    #
    # @param query [String] The query text
    # @param documents [Array<String>] Documents to score against the query
    # @param batch_size [Integer] Number of pairs to process at once
    # @return [Async::Future] Future that resolves to array of scores
    def rerank_async(query:, documents:, batch_size: 64)
      Async::Future.new { rerank(query: query, documents: documents, batch_size: batch_size) }
    end

    # Rerank documents with scores asynchronously
    #
    # @param query [String] The query text
    # @param documents [Array<String>] Documents to rerank
    # @param top_k [Integer, nil] Return only top K results (nil = all)
    # @param batch_size [Integer] Number of pairs to process at once
    # @return [Async::Future] Future that resolves to sorted results array
    def rerank_with_scores_async(query:, documents:, top_k: nil, batch_size: 64)
      Async::Future.new { rerank_with_scores(query: query, documents: documents, top_k: top_k, batch_size: batch_size) }
    end

    # List all supported reranker models (built-in and custom)
    #
    # When a custom registration reuses a built-in name, the built-in entry is
    # listed. This matches the lookup order of #resolve_model_info.
    #
    # @return [Array<Hash>] Array of model information hashes
    def self.list_supported_models
      all_models = SUPPORTED_RERANKER_MODELS.merge(CustomModelRegistry.reranker_models) do |_name, builtin, _custom|
        builtin
      end
      all_models.values.map(&:to_h)
    end

    private

    # Find model metadata: built-in registry first, then custom registrations.
    #
    # @raise [Fastembed::Error] If the model is unknown
    def resolve_model_info(model_name)
      info = SUPPORTED_RERANKER_MODELS[model_name]
      return info if info

      info = CustomModelRegistry.reranker_models[model_name]
      return info if info

      raise Error, "Unknown reranker model: #{model_name}"
    end

    # Build metadata for a model loaded from a local directory.
    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
      RerankerModelInfo.new(
        model_name: model_name,
        description: 'Local reranker model',
        size_in_gb: 0,
        sources: {},
        model_file: model_file || 'model.onnx',
        tokenizer_file: tokenizer_file || 'tokenizer.json'
      )
    end

    # Score +documents+ in slices of +batch_size+; one score per document, input order kept.
    # An empty list returns [] without touching the model.
    def score_documents(query, documents, batch_size)
      return [] if documents.empty?

      documents.each_slice(batch_size).flat_map { |batch| score_pairs(query, batch) }
    end

    # Run one batch of (query, document) pairs through the ONNX session.
    def score_pairs(query, documents)
      encodings = tokenize_pairs(query, documents)
      inputs = prepare_pair_inputs(encodings)
      extract_scores(@session.run(nil, inputs))
    end

    # Encode each document together with the query as a sequence pair.
    def tokenize_pairs(query, documents)
      documents.map { |doc| @tokenizer.encode(query, doc) }
    end

    # Right-pad every encoding to the longest one in the batch and build the
    # ONNX input feed. Only inputs that the session declares are included.
    def prepare_pair_inputs(encodings)
      max_len = encodings.map { |e| e.ids.length }.max

      input_ids = []
      attention_mask = []
      token_type_ids = []

      encodings.each do |encoding|
        pad_len = max_len - encoding.ids.length
        input_ids << pad_sequence(encoding.ids, pad_len)
        attention_mask << pad_sequence(encoding.attention_mask, pad_len)
        token_type_ids << pad_sequence(encoding.type_ids, pad_len)
      end

      input_names = session_input_names
      inputs = { 'input_ids' => input_ids }

      # Some exports name the mask / segment inputs differently; use whichever the graph declares.
      mask_key = %w[attention_mask input_mask].find { |name| input_names.include?(name) }
      inputs[mask_key] = attention_mask if mask_key

      type_key = %w[token_type_ids segment_ids].find { |name| input_names.include?(name) }
      inputs[type_key] = token_type_ids if type_key

      inputs
    end

    # Append +pad_len+ zeros to +sequence+.
    # NOTE(review): pads input_ids with id 0 as well; masked positions should be
    # ignored, but confirm this is harmless for tokenizers whose pad id is not 0.
    def pad_sequence(sequence, pad_len)
      sequence + ([0] * pad_len)
    end

    # Take the first model output and reduce each row to a single score.
    # Array rows (e.g. one logit per pair) contribute their first element;
    # scalar rows pass through unchanged.
    # NOTE(review): for models emitting several logits per pair this picks
    # index 0 — confirm that is the relevance logit.
    def extract_scores(outputs)
      outputs.first.map { |logit| logit.is_a?(Array) ? logit.first : logit }
    end
  end
end
@@ -15,8 +15,17 @@ module Fastembed
15
15
  # # Process each vector
16
16
  # end
17
17
  #
18
+ # @example Load from local directory
19
+ # embedding = Fastembed::TextEmbedding.new(
20
+ # local_model_dir: "/path/to/model",
21
+ # model_file: "model.onnx",
22
+ # tokenizer_file: "tokenizer.json"
23
+ # )
24
+ #
18
25
  class TextEmbedding
19
- attr_reader :model_name, :model_info, :dim
26
+ include BaseModel
27
+
28
+ attr_reader :dim
20
29
 
21
30
  # Initialize a text embedding model
22
31
  #
@@ -25,55 +34,84 @@ module Fastembed
25
34
  # @param threads [Integer, nil] Number of threads for ONNX Runtime
26
35
  # @param providers [Array<String>, nil] ONNX execution providers (e.g., ["CoreMLExecutionProvider"])
27
36
  # @param show_progress [Boolean] Whether to show download progress
37
+ # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
38
+ # @param local_model_dir [String, nil] Load model from local directory instead of downloading
39
+ # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
40
+ # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
28
41
  def initialize(
29
42
  model_name: DEFAULT_MODEL,
30
43
  cache_dir: nil,
31
44
  threads: nil,
32
45
  providers: nil,
33
- show_progress: true
46
+ show_progress: true,
47
+ quantization: nil,
48
+ local_model_dir: nil,
49
+ model_file: nil,
50
+ tokenizer_file: nil
34
51
  )
35
- @model_name = model_name
36
- @threads = threads
37
- @providers = providers
38
- @show_progress = show_progress
39
-
40
- # Set custom cache directory if provided
41
- ModelManagement.cache_dir = cache_dir if cache_dir
52
+ if local_model_dir
53
+ initialize_from_local(
54
+ local_model_dir: local_model_dir,
55
+ model_name: model_name,
56
+ threads: threads,
57
+ providers: providers,
58
+ quantization: quantization,
59
+ model_file: model_file,
60
+ tokenizer_file: tokenizer_file
61
+ )
62
+ else
63
+ initialize_model(
64
+ model_name: model_name,
65
+ cache_dir: cache_dir,
66
+ threads: threads,
67
+ providers: providers,
68
+ show_progress: show_progress,
69
+ quantization: quantization
70
+ )
71
+ end
42
72
 
43
- # Resolve model info
44
- @model_info = ModelManagement.resolve_model_info(model_name)
45
73
  @dim = @model_info.dim
46
-
47
- # Download and load model
48
- @model_dir = ModelManagement.retrieve_model(model_name, show_progress: show_progress)
49
- @model = OnnxEmbeddingModel.new(@model_info, @model_dir, threads: threads, providers: providers)
74
+ @model = OnnxEmbeddingModel.new(
75
+ @model_info,
76
+ @model_dir,
77
+ threads: threads,
78
+ providers: providers,
79
+ model_file_override: model_file || quantized_model_file
80
+ )
50
81
  end
51
82
 
52
83
  # Generate embeddings for documents
53
84
  #
54
85
  # @param documents [Array<String>, String] Text document(s) to embed
55
86
  # @param batch_size [Integer] Number of documents to process at once
87
+ # @yield [Progress] Optional progress callback called after each batch
56
88
  # @return [Enumerator] Lazy enumerator yielding embedding vectors
57
89
  # @raise [ArgumentError] If documents is nil or contains nil values
58
90
  #
59
- # @example
91
+ # @example Basic usage
60
92
  # vectors = embedding.embed(["Hello", "World"]).to_a
61
93
  # # => [[0.1, 0.2, ...], [0.3, 0.4, ...]]
62
- def embed(documents, batch_size: 256)
63
- raise ArgumentError, 'documents cannot be nil' if documents.nil?
64
-
65
- documents = [documents] if documents.is_a?(String)
94
+ #
95
+ # @example With progress callback
96
+ # embedding.embed(documents, batch_size: 64) do |progress|
97
+ # puts "#{progress.percent}% complete"
98
+ # end.to_a
99
+ #
100
+ def embed(documents, batch_size: 256, &progress_callback)
101
+ documents = Validators.validate_documents!(documents)
66
102
  return Enumerator.new { |_| } if documents.empty?
67
103
 
68
- # Validate all documents
69
- documents.each_with_index do |doc, i|
70
- raise ArgumentError, "document at index #{i} cannot be nil" if doc.nil?
71
- end
104
+ total_batches = (documents.length.to_f / batch_size).ceil
72
105
 
73
106
  Enumerator.new do |yielder|
74
- documents.each_slice(batch_size) do |batch|
107
+ documents.each_slice(batch_size).with_index(1) do |batch, batch_num|
75
108
  embeddings = @model.embed(batch)
76
109
  embeddings.each { |embedding| yielder << embedding }
110
+
111
+ if progress_callback
112
+ progress = Progress.new(current: batch_num, total: total_batches, batch_size: batch_size)
113
+ progress_callback.call(progress)
114
+ end
77
115
  end
78
116
  end
79
117
  end
@@ -100,11 +138,60 @@ module Fastembed
100
138
  embed(prefixed, batch_size: batch_size)
101
139
  end
102
140
 
103
- # List all supported models
141
# Embed documents in the background and return a Future immediately.
#
# The Future's value is the fully materialized result of #embed, meaning all
# vectors collected into an Array. Calling `value` blocks until that work is done.
# NOTE(review): whether several futures actually run in parallel depends on
# Async::Future and on whether the ONNX call releases the GVL — confirm.
#
# @param documents [Array<String>, String] Text document(s) to embed
# @param batch_size [Integer] Number of documents to process at once
# @return [Async::Future] Future that resolves to array of embedding vectors
#
# @example Basic async usage
#   future = embedding.embed_async(documents)
#   # ... do other work ...
#   vectors = future.value # blocks until complete
#
# @example Several batches in flight at once
#   futures = documents.each_slice(1000).map { |batch| embedding.embed_async(batch) }
#   all_vectors = futures.flat_map(&:value)
#
def embed_async(documents, batch_size: 256)
  Async::Future.new { embed(documents, batch_size: batch_size).to_a }
end
166
+
167
# Background variant of #query_embed: returns a Future right away.
#
# The Future resolves to the result of #query_embed collected into an Array.
#
# @param queries [Array<String>, String] Query text(s) to embed
# @param batch_size [Integer] Number of queries to process at once
# @return [Async::Future] Future that resolves to array of embedding vectors
def query_embed_async(queries, batch_size: 256)
  Async::Future.new { query_embed(queries, batch_size: batch_size).to_a }
end
177
+
178
# Background variant of #passage_embed: returns a Future right away.
#
# The Future resolves to the result of #passage_embed collected into an Array.
#
# @param passages [Array<String>, String] Passage text(s) to embed
# @param batch_size [Integer] Number of passages to process at once
# @return [Async::Future] Future that resolves to array of embedding vectors
def passage_embed_async(passages, batch_size: 256)
  Async::Future.new { passage_embed(passages, batch_size: batch_size).to_a }
end
188
+
189
# List all supported models (built-in and custom)
#
# When a custom registration reuses a built-in name, the built-in entry is
# listed. This matches the lookup order of get_model_info
# (SUPPORTED_MODELS first, then CustomModelRegistry).
#
# @return [Array<Hash>] Array of model information hashes
def self.list_supported_models
  all_models = SUPPORTED_MODELS.merge(CustomModelRegistry.embedding_models) do |_name, builtin, _custom|
    builtin
  end
  all_models.values.map(&:to_h)
end
109
196
 
110
197
  # Get information about a specific model
@@ -112,7 +199,53 @@ module Fastembed
112
199
  # @param model_name [String] Name of the model
113
200
  # @return [Hash, nil] Model information or nil if not found
114
201
  def self.get_model_info(model_name)
115
- SUPPORTED_MODELS[model_name]&.to_h
202
+ info = SUPPORTED_MODELS[model_name] || CustomModelRegistry.embedding_models[model_name]
203
+ info&.to_h
204
+ end
205
+
206
+ private
207
+
208
# Delegate model metadata lookup to ModelManagement.
#
# @param name [String] Model name to look up
# @return [Object] Whatever ModelManagement.resolve_model_info returns for +name+
def resolve_model_info(name)
  ModelManagement.resolve_model_info(name)
end
211
+
212
# Build ModelInfo for a model loaded from a local directory.
#
# The embedding width comes from detect_model_dimension. If that returns nil,
# the width falls back to 384.
# NOTE(review): nothing visible in this class revises the fallback after the
# model is loaded, so a wrong guess would persist — confirm this is intended.
#
# @param model_name [String] Name to record for the model
# @param model_file [String, nil] ONNX file name (defaults to "model.onnx")
# @param tokenizer_file [String, nil] Tokenizer file name (defaults to "tokenizer.json")
# @return [ModelInfo]
def create_local_model_info(model_name:, model_file:, tokenizer_file:)
  onnx_file = model_file || 'model.onnx'
  detected_dim = detect_model_dimension(onnx_file)

  ModelInfo.new(
    model_name: model_name,
    dim: detected_dim || 384,
    description: 'Local model',
    size_in_gb: 0,
    sources: {},
    model_file: onnx_file,
    tokenizer_file: tokenizer_file || 'tokenizer.json'
  )
end
225
+
226
# Best-effort lookup of the embedding width from the ONNX graph.
#
# Opens a separate InferenceSession on the model file and reads the declared
# shape of its first output. The last axis is taken as the embedding dimension;
# it is usually the hidden size, e.g. [batch, seq_len, hidden] or [batch, hidden].
#
# @param model_file [String, nil] Model filename inside @model_dir (defaults to "model.onnx")
# @return [Integer, nil] Detected dimension, or nil when the file is missing,
#   the shape is absent or not a positive Integer, or inspection raises
#   (a warning is printed in the latter case)
def detect_model_dimension(model_file)
  model_path = File.join(@model_dir, model_file || 'model.onnx')
  return nil unless File.exist?(model_path)

  first_output = OnnxRuntime::InferenceSession.new(model_path).outputs.first
  shape = first_output && first_output[:shape]
  return nil unless shape

  last_axis = shape.last
  last_axis if last_axis.is_a?(Integer) && last_axis.positive?
rescue OnnxRuntime::Error => e
  warn "Warning: Could not detect model dimension: #{e.message}"
  nil
rescue StandardError => e
  warn "Warning: Unexpected error detecting model dimension: #{e.class} - #{e.message}"
  nil
end
117
250
  end
118
251
  end
@@ -0,0 +1,59 @@
1
+ # frozen_string_literal: true
2
+
3
module Fastembed
  # Input validation helpers for embedding models
  #
  # Provides consistent validation and normalization of documents, queries,
  # and other inputs across all model types.
  #
  # @api private
  #
  module Validators
    class << self
      # Validate and normalize document input
      #
      # Rejects nil and wraps a single String in a one-element array. Rejects
      # inputs that are neither a String nor enumerable. Also checks that no
      # individual document is nil.
      #
      # @param documents [Array<String>, String, nil] Documents to validate
      # @return [Array<String>] Normalized array of documents
      # @raise [ArgumentError] If documents is nil, is not a String or a
      #   collection, or contains nil values
      def validate_documents!(documents)
        raise ArgumentError, 'documents cannot be nil' if documents.nil?

        documents = [documents] if documents.is_a?(String)
        unless documents.respond_to?(:find_index)
          raise ArgumentError, "documents must be a String or an Array of Strings, got #{documents.class}"
        end

        nil_index = documents.find_index(&:nil?)
        raise ArgumentError, "document at index #{nil_index} cannot be nil" if nil_index

        documents
      end

      # Validate query and documents for reranking
      #
      # Documents follow the same rules as validate_documents!, including
      # wrapping a single String document in an array.
      #
      # @param query [String, nil] Query to validate
      # @param documents [Array<String>, String, nil] Documents to validate
      # @return [Array<String>] Validated documents array
      # @raise [ArgumentError] If query or documents is nil, or documents contains nil
      def validate_rerank_input!(query:, documents:)
        raise ArgumentError, 'query cannot be nil' if query.nil?

        validate_documents!(documents)
      end

      # Check if documents array is empty
      #
      # @param documents [Array<String>] Documents to check
      # @return [Boolean] True if empty
      def empty?(documents)
        documents.empty?
      end
    end
  end
end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Fastembed
4
- VERSION = '1.0.0'
4
+ VERSION = '1.1.0'
5
5
  end
data/lib/fastembed.rb CHANGED
@@ -2,13 +2,54 @@
2
2
 
3
3
  require_relative 'fastembed/version'
4
4
 
5
+ # Fastembed - Fast, lightweight text embeddings for Ruby
6
+ #
7
+ # A Ruby port of FastEmbed providing text embeddings using ONNX Runtime.
8
+ # Supports dense embeddings, sparse embeddings (SPLADE), late interaction (ColBERT),
9
+ # and cross-encoder reranking.
10
+ #
11
+ # @example Basic text embedding
12
+ # embedding = Fastembed::TextEmbedding.new
13
+ # vectors = embedding.embed(["Hello world", "Ruby is great"]).to_a
14
+ #
15
+ # @example Reranking documents
16
+ # reranker = Fastembed::TextCrossEncoder.new
17
+ # scores = reranker.rerank(query: "What is ML?", documents: docs)
18
+ #
19
+ # @example Sparse embeddings
20
+ # sparse = Fastembed::TextSparseEmbedding.new
21
+ # embeddings = sparse.embed(["Hello"]).to_a
22
+ #
23
+ # @example Async embedding
24
+ # future = embedding.embed_async(large_document_list)
25
+ # vectors = future.value # blocks until complete
26
+ #
27
+ # @see https://github.com/khasinski/fastembed-rb
28
+ #
5
29
  module Fastembed
30
+ # Base error class for all Fastembed errors
6
31
  class Error < StandardError; end
32
+
33
+ # Raised when model download fails
7
34
  class DownloadError < Error; end
8
35
  end
9
36
 
37
+ require_relative 'fastembed/pooling'
38
+ require_relative 'fastembed/base_model_info'
10
39
  require_relative 'fastembed/model_info'
40
+ require_relative 'fastembed/reranker_model_info'
41
+ require_relative 'fastembed/sparse_model_info'
42
+ require_relative 'fastembed/late_interaction_model_info'
43
+ require_relative 'fastembed/custom_model_registry'
11
44
  require_relative 'fastembed/model_management'
12
- require_relative 'fastembed/pooling'
45
+ require_relative 'fastembed/quantization'
46
+ require_relative 'fastembed/progress'
47
+ require_relative 'fastembed/async'
48
+ require_relative 'fastembed/validators'
49
+ require_relative 'fastembed/base_model'
13
50
  require_relative 'fastembed/onnx_embedding_model'
14
51
  require_relative 'fastembed/text_embedding'
52
+ require_relative 'fastembed/text_cross_encoder'
53
+ require_relative 'fastembed/sparse_embedding'
54
+ require_relative 'fastembed/late_interaction_embedding'
55
+ require_relative 'fastembed/image_embedding'