fastembed 1.0.0 → 1.1.0

@@ -5,20 +5,41 @@ require 'tokenizers'
 
 module Fastembed
   # ONNX-based embedding model wrapper
+  #
+  # Handles the low-level details of running ONNX inference and tokenization
+  # for embedding models. Used internally by TextEmbedding.
+  #
+  # @api private
+  #
   class OnnxEmbeddingModel
+    # @!attribute [r] model_info
+    #   @return [ModelInfo] Model metadata
+    # @!attribute [r] model_dir
+    #   @return [String] Path to model directory
     attr_reader :model_info, :model_dir
 
-    def initialize(model_info, model_dir, threads: nil, providers: nil)
+    # @param model_info [ModelInfo] Model metadata
+    # @param model_dir [String] Directory containing model files
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param model_file_override [String, nil] Override model file path (for quantized models)
+    def initialize(model_info, model_dir, threads: nil, providers: nil, model_file_override: nil)
       @model_info = model_info
       @model_dir = model_dir
       @threads = threads
       @providers = providers
+      @model_file = model_file_override || model_info.model_file
 
       load_model
       load_tokenizer
     end
 
     # Embed a batch of texts
+    #
+    # Tokenizes the input, runs ONNX inference, and applies pooling.
+    #
+    # @param texts [Array<String>] Texts to embed
+    # @return [Array<Array<Float>>] Embedding vectors
     def embed(texts)
       # Tokenize
       encoded = tokenize(texts)
@@ -38,7 +59,7 @@ module Fastembed
     private
 
     def load_model
-      model_path = File.join(model_dir, model_info.model_file)
+      model_path = File.join(model_dir, @model_file)
       raise Error, "Model file not found: #{model_path}" unless File.exist?(model_path)
 
       session_options = {}
@@ -53,11 +74,11 @@ module Fastembed
       tokenizer_path = File.join(model_dir, model_info.tokenizer_file)
       raise Error, "Tokenizer file not found: #{tokenizer_path}" unless File.exist?(tokenizer_path)
 
-      @tokenizer = Tokenizers.from_file(tokenizer_path)
+      @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_path)
 
       # Configure tokenizer for batch encoding
       @tokenizer.enable_padding(pad_id: 0, pad_token: '[PAD]')
-      @tokenizer.enable_truncation(512)
+      @tokenizer.enable_truncation(model_info.max_length)
     end
 
     def tokenize(texts)
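
The new model_file_override keyword lets a caller point the wrapper at an alternate ONNX file, such as a quantized variant, without touching the registry metadata. A minimal sketch of how that composes with the Quantization helper added later in this diff; the model directory path is hypothetical, and OnnxEmbeddingModel remains @api private:

    info = Fastembed::SUPPORTED_SPARSE_MODELS['prithivida/Splade_PP_en_v1']

    # Derive the int8 variant of the registry default ("onnx/model.onnx").
    quantized = Fastembed::Quantization.model_file(info.model_file, :int8)
    # => "onnx/model_int8.onnx"

    model = Fastembed::OnnxEmbeddingModel.new(
      info,
      '/path/to/downloaded/model',  # hypothetical, already-fetched directory
      model_file_override: quantized
    )
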
@@ -1,10 +1,32 @@
 # frozen_string_literal: true
 
 module Fastembed
-  # Pooling strategies for transformer outputs
+  # Pooling strategies for transformer model outputs
+  #
+  # Transforms variable-length token embeddings into fixed-size sentence embeddings.
+  # Supports mean pooling (default) and CLS token pooling.
+  #
+  # @example Apply mean pooling with normalization
+  #   pooled = Pooling.apply(:mean, token_embeddings, attention_mask)
+  #
   module Pooling
+    # Valid pooling strategies
+    VALID_STRATEGIES = %i[mean cls].freeze
+
+    # Check if a pooling strategy is valid
+    #
+    # @param strategy [Symbol] Pooling strategy to check
+    # @return [Boolean] True if valid
+    def self.valid?(strategy)
+      VALID_STRATEGIES.include?(strategy)
+    end
+
     class << self
       # Mean pooling - averages all token embeddings weighted by attention mask
+      #
+      # @param token_embeddings [Array<Array<Array<Float>>>] Token embeddings [batch, seq, hidden]
+      # @param attention_mask [Array<Array<Integer>>] Attention mask [batch, seq]
+      # @return [Array<Array<Float>>] Pooled embeddings [batch, hidden]
       def mean_pooling(token_embeddings, attention_mask)
         # token_embeddings: [batch_size, seq_len, hidden_size]
         # attention_mask: [batch_size, seq_len]
@@ -40,11 +62,18 @@ module Fastembed
       end
 
       # CLS pooling - uses the [CLS] token embedding (first token)
+      #
+      # @param token_embeddings [Array<Array<Array<Float>>>] Token embeddings [batch, seq, hidden]
+      # @param _attention_mask [Array<Array<Integer>>] Attention mask (unused)
+      # @return [Array<Array<Float>>] Pooled embeddings [batch, hidden]
       def cls_pooling(token_embeddings, _attention_mask)
         token_embeddings.map { |batch| batch[0] }
       end
 
-      # L2 normalize vectors
+      # L2 normalize vectors to unit length
+      #
+      # @param vectors [Array<Array<Float>>] Vectors to normalize
+      # @return [Array<Array<Float>>] Normalized vectors
       def normalize(vectors)
         vectors.map do |vector|
           norm = Math.sqrt(vector.sum { |v| v * v })
@@ -53,7 +82,14 @@ module Fastembed
         end
       end
 
-      # Apply pooling based on strategy
+      # Apply pooling strategy to token embeddings
+      #
+      # @param strategy [Symbol] Pooling strategy (:mean or :cls)
+      # @param token_embeddings [Array] Token embeddings from model
+      # @param attention_mask [Array] Attention mask
+      # @param should_normalize [Boolean] Whether to L2 normalize output
+      # @return [Array<Array<Float>>] Pooled embeddings
+      # @raise [ArgumentError] If unknown pooling strategy
       def apply(strategy, token_embeddings, attention_mask, should_normalize: true)
         pooled = case strategy
                  when :mean
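
The pooling contract is easiest to see with a tiny worked example: one sequence of three tokens, hidden size 2, with the last position padded. The numbers are illustrative and assume mean_pooling excludes masked positions from the average, as its doc comment describes:

    token_embeddings = [[[1.0, 3.0], [3.0, 5.0], [9.0, 9.0]]]  # [batch=1, seq=3, hidden=2]
    attention_mask   = [[1, 1, 0]]                             # third token is padding

    # Masked mean: ([1,3] + [3,5]) / 2 => [2.0, 4.0]; the padded [9,9] is ignored.
    Fastembed::Pooling.apply(:mean, token_embeddings, attention_mask, should_normalize: false)
    # => [[2.0, 4.0]]

    # With the default should_normalize: true the result is scaled to unit L2 length:
    # [2, 4] / sqrt(2**2 + 4**2) => [0.4472..., 0.8944...]
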
@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Progress information for batch operations
+  # Passed to progress callbacks during embedding
+  class Progress
+    attr_reader :current, :total, :batch_size
+
+    # @param current [Integer] Current batch number (1-indexed)
+    # @param total [Integer] Total number of batches
+    # @param batch_size [Integer] Size of each batch
+    def initialize(current:, total:, batch_size:)
+      @current = current
+      @total = total
+      @batch_size = batch_size
+    end
+
+    # Fraction complete (0.0 to 1.0)
+    # @return [Float]
+    def percentage
+      return 1.0 if total.zero?
+
+      current.to_f / total
+    end
+
+    # Percentage as an integer (0 to 100)
+    # @return [Integer]
+    def percent
+      (percentage * 100).round
+    end
+
+    # Number of documents processed so far
+    # @return [Integer]
+    def documents_processed
+      current * batch_size
+    end
+
+    # Check if processing is complete
+    # @return [Boolean]
+    def complete?
+      current >= total
+    end
+
+    def to_s
+      "Progress(#{current}/#{total}, #{percent}%)"
+    end
+
+    def inspect
+      to_s
+    end
+  end
+end
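
The callback contract is easiest to see in use. A short sketch, assuming the block form of embed shown later in this diff (the enumerator is lazy, so the callback fires as batches are consumed):

    sparse = Fastembed::TextSparseEmbedding.new
    docs = ['first doc', 'second doc', 'third doc']

    sparse.embed(docs, batch_size: 2) do |progress|
      # With 3 docs and batch_size 2: Progress(1/2, 50%), then Progress(2/2, 100%)
      puts progress            # uses Progress#to_s
      puts progress.complete?  # => false, then true
    end.to_a
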
@@ -0,0 +1,75 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Quantization options for ONNX models
+  # Different quantization levels trade off model size/speed vs accuracy
+  module Quantization
+    # Available quantization types
+    TYPES = {
+      fp32: {
+        suffix: '',
+        description: 'Full precision (32-bit float)',
+        size_multiplier: 1.0
+      },
+      fp16: {
+        suffix: '_fp16',
+        description: 'Half precision (16-bit float)',
+        size_multiplier: 0.5
+      },
+      int8: {
+        suffix: '_int8',
+        description: 'Dynamic int8 quantization',
+        size_multiplier: 0.25
+      },
+      uint8: {
+        suffix: '_uint8',
+        description: 'Dynamic uint8 quantization',
+        size_multiplier: 0.25
+      },
+      q4: {
+        suffix: '_q4',
+        description: '4-bit quantization',
+        size_multiplier: 0.125
+      }
+    }.freeze
+
+    # Default quantization type
+    DEFAULT = :fp32
+
+    class << self
+      # Check if a quantization type is valid
+      # @param type [Symbol] Quantization type
+      # @return [Boolean]
+      def valid?(type)
+        TYPES.key?(type)
+      end
+
+      # Get the model file suffix for a quantization type
+      # @param type [Symbol] Quantization type
+      # @return [String] Suffix to append to model filename
+      def suffix(type)
+        TYPES.dig(type, :suffix) || ''
+      end
+
+      # Get quantized model filename
+      # @param base_file [String] Base model file (e.g., "onnx/model.onnx")
+      # @param type [Symbol] Quantization type
+      # @return [String] Quantized model filename
+      def model_file(base_file, type)
+        return base_file if type == :fp32 || type.nil?
+
+        ext = File.extname(base_file)
+        base = base_file.chomp(ext)
+        "#{base}#{suffix(type)}#{ext}"
+      end
+
+      # List available quantization types
+      # @return [Array<Hash>] Array of quantization info
+      def list
+        TYPES.map do |type, info|
+          { type: type, description: info[:description], size_multiplier: info[:size_multiplier] }
+        end
+      end
+    end
+  end
+end
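
The filename derivation is plain suffix insertion before the extension, so the mapping is easy to verify:

    Fastembed::Quantization.model_file('onnx/model.onnx', :int8)
    # => "onnx/model_int8.onnx"

    Fastembed::Quantization.model_file('onnx/model.onnx', :fp32)
    # => "onnx/model.onnx"  (fp32 is the unsuffixed base file)

    Fastembed::Quantization.suffix(:q4)
    # => "_q4"
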
@@ -0,0 +1,91 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Model information for reranker/cross-encoder models
+  #
+  # Cross-encoders process query-document pairs together rather than
+  # encoding them separately, enabling more accurate relevance scoring.
+  #
+  # @example Access reranker info
+  #   info = Fastembed::SUPPORTED_RERANKER_MODELS['BAAI/bge-reranker-base']
+  #   info.model_name # => "BAAI/bge-reranker-base"
+  #
+  class RerankerModelInfo
+    include BaseModelInfo
+
+    # Create a new RerankerModelInfo instance
+    #
+    # @param model_name [String] Full model identifier
+    # @param description [String] Human-readable description
+    # @param size_in_gb [Float] Model size in GB
+    # @param sources [Hash] Source repositories
+    # @param model_file [String] Path to ONNX model file
+    # @param tokenizer_file [String] Path to tokenizer file
+    def initialize(
+      model_name:,
+      description:,
+      size_in_gb:,
+      sources:,
+      model_file: 'onnx/model.onnx',
+      tokenizer_file: 'tokenizer.json'
+    )
+      initialize_base(
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file
+      )
+    end
+
+    # Convert to hash representation
+    # @return [Hash] Model info as a hash
+    def to_h
+      {
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file
+      }
+    end
+  end
+
+  # Registry of supported reranker models
+  SUPPORTED_RERANKER_MODELS = {
+    'BAAI/bge-reranker-base' => RerankerModelInfo.new(
+      model_name: 'BAAI/bge-reranker-base',
+      description: 'BGE reranker base model for relevance scoring',
+      size_in_gb: 1.11,
+      sources: { hf: 'Xenova/bge-reranker-base' }
+    ),
+    'BAAI/bge-reranker-large' => RerankerModelInfo.new(
+      model_name: 'BAAI/bge-reranker-large',
+      description: 'BGE reranker large model for higher accuracy',
+      size_in_gb: 2.24,
+      sources: { hf: 'Xenova/bge-reranker-large' }
+    ),
+    'cross-encoder/ms-marco-MiniLM-L-6-v2' => RerankerModelInfo.new(
+      model_name: 'cross-encoder/ms-marco-MiniLM-L-6-v2',
+      description: 'Lightweight MS MARCO cross-encoder',
+      size_in_gb: 0.08,
+      sources: { hf: 'Xenova/ms-marco-MiniLM-L-6-v2' }
+    ),
+    'cross-encoder/ms-marco-MiniLM-L-12-v2' => RerankerModelInfo.new(
+      model_name: 'cross-encoder/ms-marco-MiniLM-L-12-v2',
+      description: 'MS MARCO cross-encoder with better accuracy',
+      size_in_gb: 0.12,
+      sources: { hf: 'Xenova/ms-marco-MiniLM-L-12-v2' }
+    ),
+    'jinaai/jina-reranker-v1-turbo-en' => RerankerModelInfo.new(
+      model_name: 'jinaai/jina-reranker-v1-turbo-en',
+      description: 'Fast Jina reranker for English',
+      size_in_gb: 0.13,
+      sources: { hf: 'jinaai/jina-reranker-v1-turbo-en' }
+    )
+  }.freeze
+
+  DEFAULT_RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
+end
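
A quick registry lookup, extending the @example above; the reader methods are assumed to come from the BaseModelInfo mixin, which this diff does not show:

    info = Fastembed::SUPPORTED_RERANKER_MODELS[Fastembed::DEFAULT_RERANKER_MODEL]
    info.model_name  # => "cross-encoder/ms-marco-MiniLM-L-6-v2"
    info.size_in_gb  # => 0.08
    info.sources     # => { hf: "Xenova/ms-marco-MiniLM-L-6-v2" }
    info.to_h.keys   # => [:model_name, :description, :size_in_gb, :sources, :model_file, :tokenizer_file]
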
@@ -0,0 +1,261 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Represents a sparse embedding vector
+  # Contains indices and their corresponding values (non-zero only)
+  class SparseEmbedding
+    attr_reader :indices, :values
+
+    def initialize(indices:, values:)
+      @indices = indices
+      @values = values
+    end
+
+    # Convert to hash format {index => value}
+    def to_h
+      indices.zip(values).to_h
+    end
+
+    # Number of non-zero elements
+    def nnz
+      indices.length
+    end
+
+    def to_s
+      "SparseEmbedding(nnz=#{nnz})"
+    end
+
+    def inspect
+      to_s
+    end
+  end
+
+  # Sparse text embedding using SPLADE models
+  #
+  # SPLADE (SParse Lexical AnD Expansion) models produce sparse vectors
+  # where each dimension corresponds to a vocabulary token.
+  # These are useful for hybrid search when combined with dense embeddings.
+  #
+  # @example Basic usage
+  #   sparse = Fastembed::TextSparseEmbedding.new
+  #   embeddings = sparse.embed(["Hello world"]).to_a
+  #   embeddings.first.indices # => [101, 7592, 2088, ...]
+  #   embeddings.first.values # => [0.45, 1.23, 0.89, ...]
+  #
+  class TextSparseEmbedding
+    include BaseModel
+
+    # Initialize a sparse embedding model
+    #
+    # @param model_name [String] Name of the model to use
+    # @param cache_dir [String, nil] Custom cache directory for models
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param show_progress [Boolean] Whether to show download progress
+    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
+    # @param local_model_dir [String, nil] Load model from local directory instead of downloading
+    # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
+    # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
+    def initialize(
+      model_name: DEFAULT_SPARSE_MODEL,
+      cache_dir: nil,
+      threads: nil,
+      providers: nil,
+      show_progress: true,
+      quantization: nil,
+      local_model_dir: nil,
+      model_file: nil,
+      tokenizer_file: nil
+    )
+      if local_model_dir
+        initialize_from_local(
+          local_model_dir: local_model_dir,
+          model_name: model_name,
+          threads: threads,
+          providers: providers,
+          quantization: quantization,
+          model_file: model_file,
+          tokenizer_file: tokenizer_file
+        )
+      else
+        initialize_model(
+          model_name: model_name,
+          cache_dir: cache_dir,
+          threads: threads,
+          providers: providers,
+          show_progress: show_progress,
+          quantization: quantization
+        )
+      end
+
+      setup_model_and_tokenizer(model_file_override: model_file || quantized_model_file)
+    end
+
+    # Generate sparse embeddings for documents
+    #
+    # @param documents [Array<String>, String] Text document(s) to embed
+    # @param batch_size [Integer] Number of documents to process at once
+    # @yield [Progress] Optional progress callback called after each batch
+    # @return [Enumerator] Lazy enumerator yielding SparseEmbedding objects
+    def embed(documents, batch_size: 32, &progress_callback)
+      documents = Validators.validate_documents!(documents)
+      return Enumerator.new { |_| } if documents.empty?
+
+      total_batches = (documents.length.to_f / batch_size).ceil
+
+      Enumerator.new do |yielder|
+        documents.each_slice(batch_size).with_index(1) do |batch, batch_num|
+          sparse_embeddings = compute_sparse_embeddings(batch)
+          sparse_embeddings.each { |emb| yielder << emb }
+
+          if progress_callback
+            progress = Progress.new(current: batch_num, total: total_batches, batch_size: batch_size)
+            progress_callback.call(progress)
+          end
+        end
+      end
+    end
+
+    # Generate sparse embeddings for queries (with "query: " prefix)
+    #
+    # @param queries [Array<String>, String] Query text(s) to embed
+    # @param batch_size [Integer] Number of queries to process at once
+    # @return [Enumerator] Lazy enumerator yielding SparseEmbedding objects
+    def query_embed(queries, batch_size: 32)
+      queries = [queries] if queries.is_a?(String)
+      prefixed = queries.map { |q| "query: #{q}" }
+      embed(prefixed, batch_size: batch_size)
+    end
+
+    # Generate sparse embeddings for passages
+    #
+    # @param passages [Array<String>, String] Passage text(s) to embed
+    # @param batch_size [Integer] Number of passages to process at once
+    # @return [Enumerator] Lazy enumerator yielding SparseEmbedding objects
+    def passage_embed(passages, batch_size: 32)
+      passages = [passages] if passages.is_a?(String)
+      embed(passages, batch_size: batch_size)
+    end
+
+    # Generate sparse embeddings asynchronously
+    #
+    # @param documents [Array<String>, String] Text document(s) to embed
+    # @param batch_size [Integer] Number of documents to process at once
+    # @return [Async::Future] Future that resolves to array of SparseEmbedding objects
+    def embed_async(documents, batch_size: 32)
+      Async::Future.new { embed(documents, batch_size: batch_size).to_a }
+    end
+
+    # Generate query embeddings asynchronously
+    #
+    # @param queries [Array<String>, String] Query text(s) to embed
+    # @param batch_size [Integer] Number of queries to process at once
+    # @return [Async::Future] Future that resolves to array of SparseEmbedding objects
+    def query_embed_async(queries, batch_size: 32)
+      Async::Future.new { query_embed(queries, batch_size: batch_size).to_a }
+    end
+
+    # Generate passage embeddings asynchronously
+    #
+    # @param passages [Array<String>, String] Passage text(s) to embed
+    # @param batch_size [Integer] Number of passages to process at once
+    # @return [Async::Future] Future that resolves to array of SparseEmbedding objects
+    def passage_embed_async(passages, batch_size: 32)
+      Async::Future.new { passage_embed(passages, batch_size: batch_size).to_a }
+    end
+
+    # List all supported sparse models
+    #
+    # @return [Array<Hash>] Array of model information hashes
+    def self.list_supported_models
+      SUPPORTED_SPARSE_MODELS.values.map(&:to_h)
+    end
+
+    private
+
+    def resolve_model_info(model_name)
+      # Check built-in registry first
+      info = SUPPORTED_SPARSE_MODELS[model_name]
+      return info if info
+
+      # Check custom registry
+      info = CustomModelRegistry.sparse_models[model_name]
+      return info if info
+
+      raise Error, "Unknown sparse model: #{model_name}"
+    end
+
+    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
+      SparseModelInfo.new(
+        model_name: model_name,
+        description: 'Local sparse model',
+        size_in_gb: 0,
+        sources: {},
+        model_file: model_file || 'model.onnx',
+        tokenizer_file: tokenizer_file || 'tokenizer.json'
+      )
+    end
+
+    def compute_sparse_embeddings(texts)
+      prepared = tokenize_and_prepare(texts)
+      outputs = @session.run(nil, prepared[:inputs])
+      logits = extract_logits(outputs)
+
+      # Convert to sparse embeddings
+      texts.length.times.map do |i|
+        create_sparse_embedding(logits[i], prepared[:attention_mask][i])
+      end
+    end
+
+    def extract_logits(outputs)
+      # SPLADE outputs logits for each token position
+      if outputs.is_a?(Hash)
+        outputs['logits'] || outputs.values.first
+      else
+        outputs.first
+      end
+    end
+
+    # Transform token logits into a sparse embedding using the SPLADE algorithm
+    #
+    # SPLADE (SParse Lexical AnD Expansion) uses a log-saturation function:
+    #   weight = log(1 + ReLU(logit))
+    #
+    # This ensures:
+    # - ReLU removes negative activations (irrelevant terms)
+    # - log1p provides saturation to prevent any term from dominating
+    # - Max-pooling across positions captures the strongest signal per vocabulary term
+    #
+    # @param token_logits [Array<Array<Float>>] Logits for each token position
+    # @param attention_mask [Array<Integer>] Mask indicating valid positions
+    # @return [SparseEmbedding] Sparse vector with vocabulary indices and weights
+    def create_sparse_embedding(token_logits, attention_mask)
+      vocab_size = token_logits.first.length
+      max_weights = Array.new(vocab_size, 0.0)
+
+      token_logits.each_with_index do |logits, pos|
+        next if attention_mask[pos].zero?
+
+        logits.each_with_index do |logit, vocab_idx|
+          # SPLADE transformation: log(1 + ReLU(x))
+          activated = logit.positive? ? logit : 0.0
+          weight = Math.log(1.0 + activated)
+          max_weights[vocab_idx] = weight if weight > max_weights[vocab_idx]
+        end
+      end
+
+      # Extract non-zero indices and values
+      indices = []
+      values = []
+
+      max_weights.each_with_index do |weight, idx|
+        if weight.positive?
+          indices << idx
+          values << weight
+        end
+      end
+
+      SparseEmbedding.new(indices: indices, values: values)
+    end
+  end
+end
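
The log-saturation transform is small enough to verify by hand. A worked micro-example with a toy vocabulary of 3 terms and two attended positions; the logits are illustrative:

    # Per vocabulary term, create_sparse_embedding takes
    # max over positions of log(1 + relu(logit)):
    token_logits   = [[2.0, -1.0, -0.5], [0.0, 3.0, -2.0]]
    attention_mask = [1, 1]

    # term 0: max(log(3.0), log(1.0)) => 1.0986...
    # term 1: max(0.0, log(4.0))      => 1.3863...  (ReLU zeroes the -1.0)
    # term 2: both logits non-positive => 0.0, so the term is dropped
    #
    # Result: SparseEmbedding(indices: [0, 1], values: [1.0986..., 1.3863...]), nnz = 2
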
@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Model information for sparse embedding models (SPLADE, etc.)
+  #
+  # Sparse models like SPLADE produce vectors where most dimensions are zero,
+  # with non-zero values corresponding to vocabulary tokens. These are useful
+  # for hybrid search when combined with dense embeddings.
+  #
+  # @example Access sparse model info
+  #   info = Fastembed::SUPPORTED_SPARSE_MODELS['prithivida/Splade_PP_en_v1']
+  #   info.model_name # => "prithivida/Splade_PP_en_v1"
+  #
+  class SparseModelInfo
+    include BaseModelInfo
+
+    # Create a new SparseModelInfo instance
+    #
+    # @param model_name [String] Full model identifier
+    # @param description [String] Human-readable description
+    # @param size_in_gb [Float] Model size in GB
+    # @param sources [Hash] Source repositories
+    # @param model_file [String] Path to ONNX model file
+    # @param tokenizer_file [String] Path to tokenizer file
+    # @param max_length [Integer] Maximum sequence length
+    def initialize(
+      model_name:,
+      description:,
+      size_in_gb:,
+      sources:,
+      model_file: 'onnx/model.onnx',
+      tokenizer_file: 'tokenizer.json',
+      max_length: BaseModelInfo::DEFAULT_MAX_LENGTH
+    )
+      initialize_base(
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file,
+        max_length: max_length
+      )
+    end
+
+    # Convert to hash representation
+    # @return [Hash] Model info as a hash
+    def to_h
+      {
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file,
+        max_length: max_length
+      }
+    end
+  end
+
+  # Registry of supported sparse embedding models
+  SUPPORTED_SPARSE_MODELS = {
+    'prithivida/Splade_PP_en_v1' => SparseModelInfo.new(
+      model_name: 'prithivida/Splade_PP_en_v1',
+      description: 'SPLADE++ model for sparse text retrieval',
+      size_in_gb: 0.53,
+      sources: { hf: 'prithivida/Splade_PP_en_v1' }
+      # Uses default model_file: 'onnx/model.onnx'
+    ),
+    'prithivida/Splade_PP_en_v2' => SparseModelInfo.new(
+      model_name: 'prithivida/Splade_PP_en_v2',
+      description: 'SPLADE++ v2 with improved performance',
+      size_in_gb: 0.53,
+      sources: { hf: 'prithivida/Splade_PP_en_v2' }
+      # Uses default model_file: 'onnx/model.onnx'
+    )
+  }.freeze
+
+  DEFAULT_SPARSE_MODEL = 'prithivida/Splade_PP_en_v1'
+end
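
Model discovery for the sparse registry mirrors the patterns above:

    Fastembed::TextSparseEmbedding.list_supported_models.map { |m| m[:model_name] }
    # => ["prithivida/Splade_PP_en_v1", "prithivida/Splade_PP_en_v2"]

    Fastembed::SUPPORTED_SPARSE_MODELS[Fastembed::DEFAULT_SPARSE_MODEL].max_length
    # => BaseModelInfo::DEFAULT_MAX_LENGTH (likely 512, matching the truncation limit the
    #    old code hard-coded; an assumption, since the mixin's constant is not shown here)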