fastembed 1.0.0 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +1 -0
- data/.yardopts +6 -0
- data/BENCHMARKS.md +124 -1
- data/CHANGELOG.md +14 -0
- data/README.md +395 -74
- data/benchmark/compare_all.rb +167 -0
- data/benchmark/compare_python.py +60 -0
- data/benchmark/memory_profile.rb +70 -0
- data/benchmark/profile.rb +198 -0
- data/benchmark/reranker_benchmark.rb +158 -0
- data/exe/fastembed +6 -0
- data/fastembed.gemspec +3 -0
- data/lib/fastembed/async.rb +193 -0
- data/lib/fastembed/base_model.rb +247 -0
- data/lib/fastembed/base_model_info.rb +61 -0
- data/lib/fastembed/cli.rb +745 -0
- data/lib/fastembed/custom_model_registry.rb +255 -0
- data/lib/fastembed/image_embedding.rb +313 -0
- data/lib/fastembed/late_interaction_embedding.rb +260 -0
- data/lib/fastembed/late_interaction_model_info.rb +91 -0
- data/lib/fastembed/model_info.rb +59 -19
- data/lib/fastembed/model_management.rb +82 -23
- data/lib/fastembed/onnx_embedding_model.rb +25 -4
- data/lib/fastembed/pooling.rb +39 -3
- data/lib/fastembed/progress.rb +52 -0
- data/lib/fastembed/quantization.rb +75 -0
- data/lib/fastembed/reranker_model_info.rb +91 -0
- data/lib/fastembed/sparse_embedding.rb +261 -0
- data/lib/fastembed/sparse_model_info.rb +80 -0
- data/lib/fastembed/text_cross_encoder.rb +217 -0
- data/lib/fastembed/text_embedding.rb +161 -28
- data/lib/fastembed/validators.rb +59 -0
- data/lib/fastembed/version.rb +1 -1
- data/lib/fastembed.rb +42 -1
- data/plan.md +257 -0
- data/scripts/verify_models.rb +229 -0
- metadata +70 -3
data/lib/fastembed/onnx_embedding_model.rb
CHANGED

@@ -5,20 +5,41 @@ require 'tokenizers'
 
 module Fastembed
   # ONNX-based embedding model wrapper
+  #
+  # Handles the low-level details of running ONNX inference and tokenization
+  # for embedding models. Used internally by TextEmbedding.
+  #
+  # @api private
+  #
   class OnnxEmbeddingModel
+    # @!attribute [r] model_info
+    #   @return [ModelInfo] Model metadata
+    # @!attribute [r] model_dir
+    #   @return [String] Path to model directory
    attr_reader :model_info, :model_dir
 
-    def initialize(model_info, model_dir, threads: nil, providers: nil)
+    # @param model_info [ModelInfo] Model metadata
+    # @param model_dir [String] Directory containing model files
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param model_file_override [String, nil] Override model file path (for quantized models)
+    def initialize(model_info, model_dir, threads: nil, providers: nil, model_file_override: nil)
       @model_info = model_info
       @model_dir = model_dir
       @threads = threads
       @providers = providers
+      @model_file = model_file_override || model_info.model_file
 
       load_model
       load_tokenizer
     end
 
     # Embed a batch of texts
+    #
+    # Tokenizes the input, runs ONNX inference, and applies pooling.
+    #
+    # @param texts [Array<String>] Texts to embed
+    # @return [Array<Array<Float>>] Embedding vectors
     def embed(texts)
       # Tokenize
       encoded = tokenize(texts)

@@ -38,7 +59,7 @@ module Fastembed
     private
 
     def load_model
-      model_path = File.join(model_dir,
+      model_path = File.join(model_dir, @model_file)
       raise Error, "Model file not found: #{model_path}" unless File.exist?(model_path)
 
       session_options = {}

@@ -53,11 +74,11 @@ module Fastembed
       tokenizer_path = File.join(model_dir, model_info.tokenizer_file)
       raise Error, "Tokenizer file not found: #{tokenizer_path}" unless File.exist?(tokenizer_path)
 
-      @tokenizer = Tokenizers.from_file(tokenizer_path)
+      @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_path)
 
       # Configure tokenizer for batch encoding
       @tokenizer.enable_padding(pad_id: 0, pad_token: '[PAD]')
-      @tokenizer.enable_truncation(
+      @tokenizer.enable_truncation(model_info.max_length)
     end
 
     def tokenize(texts)
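The new `model_file_override` keyword is what lets higher-level classes point this wrapper at a quantized ONNX file instead of the registry default. A minimal sketch of how the pieces fit together; the registry constant, lookup key, and cache path below are illustrative assumptions, while `OnnxEmbeddingModel` and its keyword arguments come from the hunk above:

    require 'fastembed'

    # `info` would normally come from one of the gem's model registries;
    # this particular constant and key are hypothetical, not from the diff.
    info = Fastembed::SUPPORTED_MODELS['BAAI/bge-small-en-v1.5']

    model = Fastembed::OnnxEmbeddingModel.new(
      info,
      '/path/to/cached/model',                     # hypothetical cache location
      model_file_override: 'onnx/model_int8.onnx'  # e.g. Quantization.model_file(info.model_file, :int8)
    )
    vectors = model.embed(['Hello world'])         # => [[0.012, -0.073, ...]]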
data/lib/fastembed/pooling.rb
CHANGED

@@ -1,10 +1,32 @@
 # frozen_string_literal: true
 
 module Fastembed
-  # Pooling strategies for transformer outputs
+  # Pooling strategies for transformer model outputs
+  #
+  # Transforms variable-length token embeddings into fixed-size sentence embeddings.
+  # Supports mean pooling (default) and CLS token pooling.
+  #
+  # @example Apply mean pooling with normalization
+  #   pooled = Pooling.apply(:mean, token_embeddings, attention_mask)
+  #
   module Pooling
+    # Valid pooling strategies
+    VALID_STRATEGIES = %i[mean cls].freeze
+
+    # Check if a pooling strategy is valid
+    #
+    # @param strategy [Symbol] Pooling strategy to check
+    # @return [Boolean] True if valid
+    def self.valid?(strategy)
+      VALID_STRATEGIES.include?(strategy)
+    end
+
     class << self
       # Mean pooling - averages all token embeddings weighted by attention mask
+      #
+      # @param token_embeddings [Array<Array<Array<Float>>>] Token embeddings [batch, seq, hidden]
+      # @param attention_mask [Array<Array<Integer>>] Attention mask [batch, seq]
+      # @return [Array<Array<Float>>] Pooled embeddings [batch, hidden]
       def mean_pooling(token_embeddings, attention_mask)
         # token_embeddings: [batch_size, seq_len, hidden_size]
         # attention_mask: [batch_size, seq_len]

@@ -40,11 +62,18 @@ module Fastembed
       end
 
       # CLS pooling - uses the [CLS] token embedding (first token)
+      #
+      # @param token_embeddings [Array<Array<Array<Float>>>] Token embeddings [batch, seq, hidden]
+      # @param _attention_mask [Array<Array<Integer>>] Attention mask (unused)
+      # @return [Array<Array<Float>>] Pooled embeddings [batch, hidden]
       def cls_pooling(token_embeddings, _attention_mask)
         token_embeddings.map { |batch| batch[0] }
       end
 
-      # L2 normalize vectors
+      # L2 normalize vectors to unit length
+      #
+      # @param vectors [Array<Array<Float>>] Vectors to normalize
+      # @return [Array<Array<Float>>] Normalized vectors
       def normalize(vectors)
         vectors.map do |vector|
           norm = Math.sqrt(vector.sum { |v| v * v })

@@ -53,7 +82,14 @@ module Fastembed
         end
       end
 
-      # Apply pooling
+      # Apply pooling strategy to token embeddings
+      #
+      # @param strategy [Symbol] Pooling strategy (:mean or :cls)
+      # @param token_embeddings [Array] Token embeddings from model
+      # @param attention_mask [Array] Attention mask
+      # @param should_normalize [Boolean] Whether to L2 normalize output
+      # @return [Array<Array<Float>>] Pooled embeddings
+      # @raise [ArgumentError] If unknown pooling strategy
       def apply(strategy, token_embeddings, attention_mask, should_normalize: true)
         pooled = case strategy
                  when :mean
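The hunks above only show the documentation around `mean_pooling`, not its body, so here is a standalone sketch of what attention-masked mean pooling computes (the arithmetic the docs describe, not the gem's exact implementation):

    # Toy batch: one sentence, three token positions, hidden size 2.
    token_embeddings = [[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]]
    attention_mask   = [[1, 1, 0]] # third position is padding

    pooled = token_embeddings.each_with_index.map do |tokens, i|
      mask  = attention_mask[i]
      count = [mask.sum, 1].max # guard against all-padding rows
      sums  = Array.new(tokens.first.length, 0.0)
      tokens.each_with_index do |vec, pos|
        next if mask[pos].zero? # padded positions contribute nothing
        vec.each_with_index { |v, d| sums[d] += v }
      end
      sums.map { |s| s / count }
    end

    pooled # => [[2.0, 3.0]] -- the mean of [1,2] and [3,4]; padding ignored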
data/lib/fastembed/progress.rb
ADDED

@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Progress information for batch operations
+  # Passed to progress callbacks during embedding
+  class Progress
+    attr_reader :current, :total, :batch_size
+
+    # @param current [Integer] Current batch number (1-indexed)
+    # @param total [Integer] Total number of batches
+    # @param batch_size [Integer] Size of each batch
+    def initialize(current:, total:, batch_size:)
+      @current = current
+      @total = total
+      @batch_size = batch_size
+    end
+
+    # Percentage complete (0.0 to 1.0)
+    # @return [Float]
+    def percentage
+      return 1.0 if total.zero?
+
+      current.to_f / total
+    end
+
+    # Percentage as integer (0 to 100)
+    # @return [Integer]
+    def percent
+      (percentage * 100).round
+    end
+
+    # Number of documents processed so far
+    # @return [Integer]
+    def documents_processed
+      current * batch_size
+    end
+
+    # Check if processing is complete
+    # @return [Boolean]
+    def complete?
+      current >= total
+    end
+
+    def to_s
+      "Progress(#{current}/#{total}, #{percent}%)"
+    end
+
+    def inspect
+      to_s
+    end
+  end
+end
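`Progress` is the object yielded to the optional block of `TextSparseEmbedding#embed` (see `sparse_embedding.rb` below), once per batch. A minimal usage sketch, assuming the model has already been downloaded:

    sparse = Fastembed::TextSparseEmbedding.new
    documents = ['first document', 'second document']

    embeddings = sparse.embed(documents, batch_size: 32) do |progress|
      # Fires after each batch as the lazy enumerator is consumed by to_a.
      puts "batch #{progress.current}/#{progress.total} (#{progress.percent}%)"
    end.to_a

One detail worth noting from the implementation: `documents_processed` is `current * batch_size`, so on a final short batch it can overshoot the true document count.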
data/lib/fastembed/quantization.rb
ADDED

@@ -0,0 +1,75 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Quantization options for ONNX models
+  # Different quantization levels trade off model size/speed vs accuracy
+  module Quantization
+    # Available quantization types
+    TYPES = {
+      fp32: {
+        suffix: '',
+        description: 'Full precision (32-bit float)',
+        size_multiplier: 1.0
+      },
+      fp16: {
+        suffix: '_fp16',
+        description: 'Half precision (16-bit float)',
+        size_multiplier: 0.5
+      },
+      int8: {
+        suffix: '_int8',
+        description: 'Dynamic int8 quantization',
+        size_multiplier: 0.25
+      },
+      uint8: {
+        suffix: '_uint8',
+        description: 'Dynamic uint8 quantization',
+        size_multiplier: 0.25
+      },
+      q4: {
+        suffix: '_q4',
+        description: '4-bit quantization',
+        size_multiplier: 0.125
+      }
+    }.freeze
+
+    # Default quantization type
+    DEFAULT = :fp32
+
+    class << self
+      # Check if a quantization type is valid
+      # @param type [Symbol] Quantization type
+      # @return [Boolean]
+      def valid?(type)
+        TYPES.key?(type)
+      end
+
+      # Get the model file suffix for a quantization type
+      # @param type [Symbol] Quantization type
+      # @return [String] Suffix to append to model filename
+      def suffix(type)
+        TYPES.dig(type, :suffix) || ''
+      end
+
+      # Get quantized model filename
+      # @param base_file [String] Base model file (e.g., "onnx/model.onnx")
+      # @param type [Symbol] Quantization type
+      # @return [String] Quantized model filename
+      def model_file(base_file, type)
+        return base_file if type == :fp32 || type.nil?
+
+        ext = File.extname(base_file)
+        base = base_file.chomp(ext)
+        "#{base}#{suffix(type)}#{ext}"
+      end
+
+      # List available quantization types
+      # @return [Array<Hash>] Array of quantization info
+      def list
+        TYPES.map do |type, info|
+          { type: type, description: info[:description], size_multiplier: info[:size_multiplier] }
+        end
+      end
+    end
+  end
+end
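`model_file` splices the quantization suffix in before the file extension, which is how `model_file_override` values like the one used in `onnx_embedding_model.rb` get built. A few results worked directly from the logic above:

    Fastembed::Quantization.model_file('onnx/model.onnx', :fp32)  # => "onnx/model.onnx" (fp32 short-circuits)
    Fastembed::Quantization.model_file('onnx/model.onnx', :int8)  # => "onnx/model_int8.onnx"
    Fastembed::Quantization.model_file('onnx/model.onnx', :q4)    # => "onnx/model_q4.onnx"
    Fastembed::Quantization.model_file('onnx/model.onnx', :bogus) # => "onnx/model.onnx" (unknown type, empty suffix)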
data/lib/fastembed/reranker_model_info.rb
ADDED

@@ -0,0 +1,91 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Model information for reranker/cross-encoder models
+  #
+  # Cross-encoders process query-document pairs together rather than
+  # encoding them separately, enabling more accurate relevance scoring.
+  #
+  # @example Access reranker info
+  #   info = Fastembed::SUPPORTED_RERANKER_MODELS['BAAI/bge-reranker-base']
+  #   info.model_name # => "BAAI/bge-reranker-base"
+  #
+  class RerankerModelInfo
+    include BaseModelInfo
+
+    # Create a new RerankerModelInfo instance
+    #
+    # @param model_name [String] Full model identifier
+    # @param description [String] Human-readable description
+    # @param size_in_gb [Float] Model size in GB
+    # @param sources [Hash] Source repositories
+    # @param model_file [String] Path to ONNX model file
+    # @param tokenizer_file [String] Path to tokenizer file
+    def initialize(
+      model_name:,
+      description:,
+      size_in_gb:,
+      sources:,
+      model_file: 'onnx/model.onnx',
+      tokenizer_file: 'tokenizer.json'
+    )
+      initialize_base(
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file
+      )
+    end
+
+    # Convert to hash representation
+    # @return [Hash] Model info as a hash
+    def to_h
+      {
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file
+      }
+    end
+  end
+
+  # Registry of supported reranker models
+  SUPPORTED_RERANKER_MODELS = {
+    'BAAI/bge-reranker-base' => RerankerModelInfo.new(
+      model_name: 'BAAI/bge-reranker-base',
+      description: 'BGE reranker base model for relevance scoring',
+      size_in_gb: 1.11,
+      sources: { hf: 'Xenova/bge-reranker-base' }
+    ),
+    'BAAI/bge-reranker-large' => RerankerModelInfo.new(
+      model_name: 'BAAI/bge-reranker-large',
+      description: 'BGE reranker large model for higher accuracy',
+      size_in_gb: 2.24,
+      sources: { hf: 'Xenova/bge-reranker-large' }
+    ),
+    'cross-encoder/ms-marco-MiniLM-L-6-v2' => RerankerModelInfo.new(
+      model_name: 'cross-encoder/ms-marco-MiniLM-L-6-v2',
+      description: 'Lightweight MS MARCO cross-encoder',
+      size_in_gb: 0.08,
+      sources: { hf: 'Xenova/ms-marco-MiniLM-L-6-v2' }
+    ),
+    'cross-encoder/ms-marco-MiniLM-L-12-v2' => RerankerModelInfo.new(
+      model_name: 'cross-encoder/ms-marco-MiniLM-L-12-v2',
+      description: 'MS MARCO cross-encoder with better accuracy',
+      size_in_gb: 0.12,
+      sources: { hf: 'Xenova/ms-marco-MiniLM-L-12-v2' }
+    ),
+    'jinaai/jina-reranker-v1-turbo-en' => RerankerModelInfo.new(
+      model_name: 'jinaai/jina-reranker-v1-turbo-en',
+      description: 'Fast Jina reranker for English',
+      size_in_gb: 0.13,
+      sources: { hf: 'jinaai/jina-reranker-v1-turbo-en' }
+    )
+  }.freeze
+
+  DEFAULT_RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
+end
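This registry feeds the new `TextCrossEncoder` (added in this release as `data/lib/fastembed/text_cross_encoder.rb`, not shown in this excerpt), so the scoring API in this sketch is an assumption; only the registry and `DEFAULT_RERANKER_MODEL` come from the hunk above:

    # The `rerank` method name and signature are assumed for illustration.
    encoder = Fastembed::TextCrossEncoder.new(model_name: Fastembed::DEFAULT_RERANKER_MODEL)
    scores  = encoder.rerank('what is a cross-encoder?', candidate_passages)

The design point the doc comment makes still holds regardless of the exact API: because each query-document pair passes through the model together, a cross-encoder cannot precompute document vectors, so it is typically used to rerank a small candidate set retrieved by cheaper dense or sparse embeddings.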
data/lib/fastembed/sparse_embedding.rb
ADDED

@@ -0,0 +1,261 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Represents a sparse embedding vector
+  # Contains indices and their corresponding values (non-zero only)
+  class SparseEmbedding
+    attr_reader :indices, :values
+
+    def initialize(indices:, values:)
+      @indices = indices
+      @values = values
+    end
+
+    # Convert to hash format {index => value}
+    def to_h
+      indices.zip(values).to_h
+    end
+
+    # Number of non-zero elements
+    def nnz
+      indices.length
+    end
+
+    def to_s
+      "SparseEmbedding(nnz=#{nnz})"
+    end
+
+    def inspect
+      to_s
+    end
+  end
+
+  # Sparse text embedding using SPLADE models
+  #
+  # SPLADE (Sparse Lexical and Expansion) models produce sparse vectors
+  # where each dimension corresponds to a vocabulary token.
+  # These are useful for hybrid search combining with dense embeddings.
+  #
+  # @example Basic usage
+  #   sparse = Fastembed::TextSparseEmbedding.new
+  #   embeddings = sparse.embed(["Hello world"]).to_a
+  #   embeddings.first.indices # => [101, 7592, 2088, ...]
+  #   embeddings.first.values  # => [0.45, 1.23, 0.89, ...]
+  #
+  class TextSparseEmbedding
+    include BaseModel
+
+    # Initialize a sparse embedding model
+    #
+    # @param model_name [String] Name of the model to use
+    # @param cache_dir [String, nil] Custom cache directory for models
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param show_progress [Boolean] Whether to show download progress
+    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
+    # @param local_model_dir [String, nil] Load model from local directory instead of downloading
+    # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
+    # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
+    def initialize(
+      model_name: DEFAULT_SPARSE_MODEL,
+      cache_dir: nil,
+      threads: nil,
+      providers: nil,
+      show_progress: true,
+      quantization: nil,
+      local_model_dir: nil,
+      model_file: nil,
+      tokenizer_file: nil
+    )
+      if local_model_dir
+        initialize_from_local(
+          local_model_dir: local_model_dir,
+          model_name: model_name,
+          threads: threads,
+          providers: providers,
+          quantization: quantization,
+          model_file: model_file,
+          tokenizer_file: tokenizer_file
+        )
+      else
+        initialize_model(
+          model_name: model_name,
+          cache_dir: cache_dir,
+          threads: threads,
+          providers: providers,
+          show_progress: show_progress,
+          quantization: quantization
+        )
+      end
+
+      setup_model_and_tokenizer(model_file_override: model_file || quantized_model_file)
+    end
+
+    # Generate sparse embeddings for documents
+    #
+    # @param documents [Array<String>, String] Text document(s) to embed
+    # @param batch_size [Integer] Number of documents to process at once
+    # @yield [Progress] Optional progress callback called after each batch
+    # @return [Enumerator] Lazy enumerator yielding SparseEmbedding objects
+    def embed(documents, batch_size: 32, &progress_callback)
+      documents = Validators.validate_documents!(documents)
+      return Enumerator.new { |_| } if documents.empty?
+
+      total_batches = (documents.length.to_f / batch_size).ceil
+
+      Enumerator.new do |yielder|
+        documents.each_slice(batch_size).with_index(1) do |batch, batch_num|
+          sparse_embeddings = compute_sparse_embeddings(batch)
+          sparse_embeddings.each { |emb| yielder << emb }
+
+          if progress_callback
+            progress = Progress.new(current: batch_num, total: total_batches, batch_size: batch_size)
+            progress_callback.call(progress)
+          end
+        end
+      end
+    end
+
+    # Generate sparse embeddings for queries (with "query: " prefix)
+    #
+    # @param queries [Array<String>, String] Query text(s) to embed
+    # @param batch_size [Integer] Number of queries to process at once
+    # @return [Enumerator] Lazy enumerator yielding SparseEmbedding objects
+    def query_embed(queries, batch_size: 32)
+      queries = [queries] if queries.is_a?(String)
+      prefixed = queries.map { |q| "query: #{q}" }
+      embed(prefixed, batch_size: batch_size)
+    end
+
+    # Generate sparse embeddings for passages
+    #
+    # @param passages [Array<String>, String] Passage text(s) to embed
+    # @param batch_size [Integer] Number of passages to process at once
+    # @return [Enumerator] Lazy enumerator yielding SparseEmbedding objects
+    def passage_embed(passages, batch_size: 32)
+      passages = [passages] if passages.is_a?(String)
+      embed(passages, batch_size: batch_size)
+    end
+
+    # Generate sparse embeddings asynchronously
+    #
+    # @param documents [Array<String>, String] Text document(s) to embed
+    # @param batch_size [Integer] Number of documents to process at once
+    # @return [Async::Future] Future that resolves to array of SparseEmbedding objects
+    def embed_async(documents, batch_size: 32)
+      Async::Future.new { embed(documents, batch_size: batch_size).to_a }
+    end
+
+    # Generate query embeddings asynchronously
+    #
+    # @param queries [Array<String>, String] Query text(s) to embed
+    # @param batch_size [Integer] Number of queries to process at once
+    # @return [Async::Future] Future that resolves to array of SparseEmbedding objects
+    def query_embed_async(queries, batch_size: 32)
+      Async::Future.new { query_embed(queries, batch_size: batch_size).to_a }
+    end
+
+    # Generate passage embeddings asynchronously
+    #
+    # @param passages [Array<String>, String] Passage text(s) to embed
+    # @param batch_size [Integer] Number of passages to process at once
+    # @return [Async::Future] Future that resolves to array of SparseEmbedding objects
+    def passage_embed_async(passages, batch_size: 32)
+      Async::Future.new { passage_embed(passages, batch_size: batch_size).to_a }
+    end
+
+    # List all supported sparse models
+    #
+    # @return [Array<Hash>] Array of model information hashes
+    def self.list_supported_models
+      SUPPORTED_SPARSE_MODELS.values.map(&:to_h)
+    end
+
+    private
+
+    def resolve_model_info(model_name)
+      # Check built-in registry first
+      info = SUPPORTED_SPARSE_MODELS[model_name]
+      return info if info
+
+      # Check custom registry
+      info = CustomModelRegistry.sparse_models[model_name]
+      return info if info
+
+      raise Error, "Unknown sparse model: #{model_name}"
+    end
+
+    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
+      SparseModelInfo.new(
+        model_name: model_name,
+        description: 'Local sparse model',
+        size_in_gb: 0,
+        sources: {},
+        model_file: model_file || 'model.onnx',
+        tokenizer_file: tokenizer_file || 'tokenizer.json'
+      )
+    end
+
+    def compute_sparse_embeddings(texts)
+      prepared = tokenize_and_prepare(texts)
+      outputs = @session.run(nil, prepared[:inputs])
+      logits = extract_logits(outputs)
+
+      # Convert to sparse embeddings
+      texts.length.times.map do |i|
+        create_sparse_embedding(logits[i], prepared[:attention_mask][i])
+      end
+    end
+
+    def extract_logits(outputs)
+      # SPLADE outputs logits for each token position
+      if outputs.is_a?(Hash)
+        outputs['logits'] || outputs.values.first
+      else
+        outputs.first
+      end
+    end
+
+    # Transform token logits into sparse embedding using SPLADE algorithm
+    #
+    # SPLADE (Sparse Lexical AnD Expansion) uses a log-saturation function:
+    #   weight = log(1 + ReLU(logit))
+    #
+    # This ensures:
+    # - ReLU removes negative activations (irrelevant terms)
+    # - log1p provides saturation to prevent any term from dominating
+    # - Max-pooling across positions captures the strongest signal per vocabulary term
+    #
+    # @param token_logits [Array<Array<Float>>] Logits for each token position
+    # @param attention_mask [Array<Integer>] Mask indicating valid positions
+    # @return [SparseEmbedding] Sparse vector with vocabulary indices and weights
+    def create_sparse_embedding(token_logits, attention_mask)
+      vocab_size = token_logits.first.length
+      max_weights = Array.new(vocab_size, 0.0)
+
+      token_logits.each_with_index do |logits, pos|
+        next if attention_mask[pos].zero?
+
+        logits.each_with_index do |logit, vocab_idx|
+          # SPLADE transformation: log(1 + ReLU(x))
+          activated = logit.positive? ? logit : 0.0
+          weight = Math.log(1.0 + activated)
+          max_weights[vocab_idx] = weight if weight > max_weights[vocab_idx]
+        end
+      end
+
+      # Extract non-zero indices and values
+      indices = []
+      values = []
+
+      max_weights.each_with_index do |weight, idx|
+        if weight.positive?
+          indices << idx
+          values << weight
+        end
+      end
+
+      SparseEmbedding.new(indices: indices, values: values)
+    end
+  end
+end
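The `log(1 + ReLU(x))` transform with max-pooling is easy to verify by hand. A standalone sketch of `create_sparse_embedding`'s arithmetic on a toy 4-term vocabulary (`Math.log` is the natural log):

    # Two valid token positions, logits over a 4-term vocabulary:
    token_logits = [[-0.5, 1.0, 0.0, 2.0],
                    [ 0.3, 0.0, 0.0, 1.0]]

    max_weights = Array.new(4, 0.0)
    token_logits.each do |logits|
      logits.each_with_index do |logit, idx|
        w = Math.log(1.0 + [logit, 0.0].max) # log(1 + ReLU(x))
        max_weights[idx] = w if w > max_weights[idx]
      end
    end

    max_weights.map { |w| w.round(3) }
    # => [0.262, 0.693, 0.0, 1.099]
    # Kept as the sparse pair -- indices: [0, 1, 3], values: [0.262, 0.693, 1.099]

Vocabulary terms the input never activates stay exactly zero, which is what keeps the vectors sparse enough for inverted-index style retrieval.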
data/lib/fastembed/sparse_model_info.rb
ADDED

@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Model information for sparse embedding models (SPLADE, etc.)
+  #
+  # Sparse models like SPLADE produce vectors where most dimensions are zero,
+  # with non-zero values corresponding to vocabulary tokens. These are useful
+  # for hybrid search combining with dense embeddings.
+  #
+  # @example Access sparse model info
+  #   info = Fastembed::SUPPORTED_SPARSE_MODELS['prithivida/Splade_PP_en_v1']
+  #   info.model_name # => "prithivida/Splade_PP_en_v1"
+  #
+  class SparseModelInfo
+    include BaseModelInfo
+
+    # Create a new SparseModelInfo instance
+    #
+    # @param model_name [String] Full model identifier
+    # @param description [String] Human-readable description
+    # @param size_in_gb [Float] Model size in GB
+    # @param sources [Hash] Source repositories
+    # @param model_file [String] Path to ONNX model file
+    # @param tokenizer_file [String] Path to tokenizer file
+    # @param max_length [Integer] Maximum sequence length
+    def initialize(
+      model_name:,
+      description:,
+      size_in_gb:,
+      sources:,
+      model_file: 'onnx/model.onnx',
+      tokenizer_file: 'tokenizer.json',
+      max_length: BaseModelInfo::DEFAULT_MAX_LENGTH
+    )
+      initialize_base(
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file,
+        max_length: max_length
+      )
+    end
+
+    # Convert to hash representation
+    # @return [Hash] Model info as a hash
+    def to_h
+      {
+        model_name: model_name,
+        description: description,
+        size_in_gb: size_in_gb,
+        sources: sources,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file,
+        max_length: max_length
+      }
+    end
+  end
+
+  # Registry of supported sparse embedding models
+  SUPPORTED_SPARSE_MODELS = {
+    'prithivida/Splade_PP_en_v1' => SparseModelInfo.new(
+      model_name: 'prithivida/Splade_PP_en_v1',
+      description: 'SPLADE++ model for sparse text retrieval',
+      size_in_gb: 0.53,
+      sources: { hf: 'prithivida/Splade_PP_en_v1' }
+      # Uses default model_file: 'onnx/model.onnx'
+    ),
+    'prithivida/Splade_PP_en_v2' => SparseModelInfo.new(
+      model_name: 'prithivida/Splade_PP_en_v2',
+      description: 'SPLADE++ v2 with improved performance',
+      size_in_gb: 0.53,
+      sources: { hf: 'prithivida/Splade_PP_en_v2' }
+      # Uses default model_file: 'onnx/model.onnx'
+    )
+  }.freeze
+
+  DEFAULT_SPARSE_MODEL = 'prithivida/Splade_PP_en_v1'
+end
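With the registry in place, model discovery and selection use only methods shown in this diff. A short sketch:

    # Enumerate the sparse models this release ships with:
    Fastembed::TextSparseEmbedding.list_supported_models.each do |m|
      puts format('%-30s %5.2f GB  %s', m[:model_name], m[:size_in_gb], m[:description])
    end

    # Instantiate the default (prithivida/Splade_PP_en_v1), optionally quantized:
    sparse = Fastembed::TextSparseEmbedding.new(quantization: :int8)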