fastembed 1.0.0 → 1.1.0
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/.rubocop.yml +1 -0
- data/.yardopts +6 -0
- data/BENCHMARKS.md +124 -1
- data/CHANGELOG.md +14 -0
- data/README.md +395 -74
- data/benchmark/compare_all.rb +167 -0
- data/benchmark/compare_python.py +60 -0
- data/benchmark/memory_profile.rb +70 -0
- data/benchmark/profile.rb +198 -0
- data/benchmark/reranker_benchmark.rb +158 -0
- data/exe/fastembed +6 -0
- data/fastembed.gemspec +3 -0
- data/lib/fastembed/async.rb +193 -0
- data/lib/fastembed/base_model.rb +247 -0
- data/lib/fastembed/base_model_info.rb +61 -0
- data/lib/fastembed/cli.rb +745 -0
- data/lib/fastembed/custom_model_registry.rb +255 -0
- data/lib/fastembed/image_embedding.rb +313 -0
- data/lib/fastembed/late_interaction_embedding.rb +260 -0
- data/lib/fastembed/late_interaction_model_info.rb +91 -0
- data/lib/fastembed/model_info.rb +59 -19
- data/lib/fastembed/model_management.rb +82 -23
- data/lib/fastembed/onnx_embedding_model.rb +25 -4
- data/lib/fastembed/pooling.rb +39 -3
- data/lib/fastembed/progress.rb +52 -0
- data/lib/fastembed/quantization.rb +75 -0
- data/lib/fastembed/reranker_model_info.rb +91 -0
- data/lib/fastembed/sparse_embedding.rb +261 -0
- data/lib/fastembed/sparse_model_info.rb +80 -0
- data/lib/fastembed/text_cross_encoder.rb +217 -0
- data/lib/fastembed/text_embedding.rb +161 -28
- data/lib/fastembed/validators.rb +59 -0
- data/lib/fastembed/version.rb +1 -1
- data/lib/fastembed.rb +42 -1
- data/plan.md +257 -0
- data/scripts/verify_models.rb +229 -0
- metadata +70 -3
data/lib/fastembed/text_cross_encoder.rb ADDED
@@ -0,0 +1,217 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Cross-encoder model for reranking query-document pairs
+  #
+  # Unlike embedding models that encode texts independently, cross-encoders
+  # process query-document pairs together to produce relevance scores.
+  # This is more accurate but slower (O(n) comparisons vs O(1) with embeddings).
+  #
+  # @example Basic reranking
+  #   reranker = Fastembed::TextCrossEncoder.new
+  #   scores = reranker.rerank(
+  #     query: "What is machine learning?",
+  #     documents: ["ML is a subset of AI...", "The weather is nice today"]
+  #   )
+  #   # => [0.95, 0.02]
+  #
+  # @example Get ranked results
+  #   results = reranker.rerank_with_scores(
+  #     query: "What is Ruby?",
+  #     documents: documents
+  #   )
+  #   # => [{document: "Ruby is a programming...", score: 0.89, index: 2}, ...]
+  #
+  class TextCrossEncoder
+    include BaseModel
+
+    # Initialize a cross-encoder model for reranking
+    #
+    # @param model_name [String] Name of the model to use
+    # @param cache_dir [String, nil] Custom cache directory for models
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param show_progress [Boolean] Whether to show download progress
+    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
+    # @param local_model_dir [String, nil] Load model from local directory instead of downloading
+    # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
+    # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
+    def initialize(
+      model_name: DEFAULT_RERANKER_MODEL,
+      cache_dir: nil,
+      threads: nil,
+      providers: nil,
+      show_progress: true,
+      quantization: nil,
+      local_model_dir: nil,
+      model_file: nil,
+      tokenizer_file: nil
+    )
+      if local_model_dir
+        initialize_from_local(
+          local_model_dir: local_model_dir,
+          model_name: model_name,
+          threads: threads,
+          providers: providers,
+          quantization: quantization,
+          model_file: model_file,
+          tokenizer_file: tokenizer_file
+        )
+      else
+        initialize_model(
+          model_name: model_name,
+          cache_dir: cache_dir,
+          threads: threads,
+          providers: providers,
+          show_progress: show_progress,
+          quantization: quantization
+        )
+      end
+
+      setup_model_and_tokenizer(model_file_override: model_file || quantized_model_file)
+    end
+
+    # Score query-document pairs and return relevance scores
+    #
+    # @param query [String] The query text
+    # @param documents [Array<String>] Documents to score against the query
+    # @param batch_size [Integer] Number of pairs to process at once
+    # @return [Array<Float>] Relevance scores for each document (higher = more relevant)
+    # @raise [ArgumentError] If query or documents is nil, or documents contains nil
+    def rerank(query:, documents:, batch_size: 64)
+      Validators.validate_rerank_input!(query: query, documents: documents)
+      return [] if documents.empty?
+
+      scores = []
+      documents.each_slice(batch_size) do |batch|
+        batch_scores = score_pairs(query, batch)
+        scores.concat(batch_scores)
+      end
+      scores
+    end
+
+    # Rerank documents and return sorted results with scores
+    #
+    # @param query [String] The query text
+    # @param documents [Array<String>] Documents to rerank
+    # @param top_k [Integer, nil] Return only top K results (nil = all)
+    # @param batch_size [Integer] Number of pairs to process at once
+    # @return [Array<Hash>] Sorted results with :document, :score, :index keys
+    def rerank_with_scores(query:, documents:, top_k: nil, batch_size: 64)
+      scores = rerank(query: query, documents: documents, batch_size: batch_size)
+
+      results = documents.zip(scores).each_with_index.map do |(doc, score), idx|
+        { document: doc, score: score, index: idx }
+      end
+
+      results.sort_by! { |r| -r[:score] }
+      top_k ? results.first(top_k) : results
+    end
+
+    # Rerank documents asynchronously
+    #
+    # @param query [String] The query text
+    # @param documents [Array<String>] Documents to score against the query
+    # @param batch_size [Integer] Number of pairs to process at once
+    # @return [Async::Future] Future that resolves to array of scores
+    def rerank_async(query:, documents:, batch_size: 64)
+      Async::Future.new { rerank(query: query, documents: documents, batch_size: batch_size) }
+    end
+
+    # Rerank documents with scores asynchronously
+    #
+    # @param query [String] The query text
+    # @param documents [Array<String>] Documents to rerank
+    # @param top_k [Integer, nil] Return only top K results (nil = all)
+    # @param batch_size [Integer] Number of pairs to process at once
+    # @return [Async::Future] Future that resolves to sorted results array
+    def rerank_with_scores_async(query:, documents:, top_k: nil, batch_size: 64)
+      Async::Future.new { rerank_with_scores(query: query, documents: documents, top_k: top_k, batch_size: batch_size) }
+    end
+
+    # List all supported reranker models
+    #
+    # @return [Array<Hash>] Array of model information hashes
+    def self.list_supported_models
+      SUPPORTED_RERANKER_MODELS.values.map(&:to_h)
+    end
+
+    private
+
+    def resolve_model_info(model_name)
+      # Check built-in registry first
+      info = SUPPORTED_RERANKER_MODELS[model_name]
+      return info if info
+
+      # Check custom registry
+      info = CustomModelRegistry.reranker_models[model_name]
+      return info if info
+
+      raise Error, "Unknown reranker model: #{model_name}"
+    end
+
+    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
+      RerankerModelInfo.new(
+        model_name: model_name,
+        description: 'Local reranker model',
+        size_in_gb: 0,
+        sources: {},
+        model_file: model_file || 'model.onnx',
+        tokenizer_file: tokenizer_file || 'tokenizer.json'
+      )
+    end
+
+    def score_pairs(query, documents)
+      encodings = tokenize_pairs(query, documents)
+      inputs = prepare_pair_inputs(encodings)
+      extract_scores(@session.run(nil, inputs))
+    end
+
+    def tokenize_pairs(query, documents)
+      documents.map { |doc| @tokenizer.encode(query, doc) }
+    end
+
+    def prepare_pair_inputs(encodings)
+      max_len = encodings.map { |e| e.ids.length }.max
+
+      input_ids = []
+      attention_mask = []
+      token_type_ids = []
+
+      encodings.each do |encoding|
+        pad_len = max_len - encoding.ids.length
+        input_ids << pad_sequence(encoding.ids, pad_len)
+        attention_mask << pad_sequence(encoding.attention_mask, pad_len)
+        token_type_ids << pad_sequence(encoding.type_ids, pad_len)
+      end
+
+      inputs = { 'input_ids' => input_ids }
+
+      # Only add attention_mask if the model expects it
+      mask_key = if session_input_names.include?('attention_mask')
+                   'attention_mask'
+                 elsif session_input_names.include?('input_mask')
+                   'input_mask'
+                 end
+      inputs[mask_key] = attention_mask if mask_key
+
+      # Only add token_type_ids if the model expects it
+      type_key = if session_input_names.include?('token_type_ids')
+                   'token_type_ids'
+                 elsif session_input_names.include?('segment_ids')
+                   'segment_ids'
+                 end
+      inputs[type_key] = token_type_ids if type_key
+
+      inputs
+    end
+
+    def pad_sequence(sequence, pad_len)
+      sequence + ([0] * pad_len)
+    end
+
+    def extract_scores(outputs)
+      outputs.first.map { |logit| logit.is_a?(Array) ? logit.first : logit }
+    end
+  end
+end
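For context, here is a minimal usage sketch of the reranking API added above. It uses only methods shown in this diff (`rerank`, `rerank_with_scores`, `rerank_async`); the query and document strings are illustrative:

```ruby
require 'fastembed'

# Downloads DEFAULT_RERANKER_MODEL on first use unless model_name: is given.
reranker = Fastembed::TextCrossEncoder.new

docs = [
  'The weather is nice today.',
  'Ruby is a dynamic, open source programming language.'
]

# Raw relevance scores, one Float per document (higher = more relevant).
scores = reranker.rerank(query: 'What is Ruby?', documents: docs)

# Sorted results that keep the original index; take only the best hit.
best = reranker.rerank_with_scores(query: 'What is Ruby?', documents: docs, top_k: 1).first
# => { document: "Ruby is a dynamic...", score: ..., index: 1 }

# Non-blocking variant: returns an Async::Future; resolve it with #value.
future = reranker.rerank_async(query: 'What is Ruby?', documents: docs)
scores_again = future.value
```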
data/lib/fastembed/text_embedding.rb CHANGED
@@ -15,8 +15,17 @@ module Fastembed
   #   #   # Process each vector
   #   #   end
   #
+  # @example Load from local directory
+  #   embedding = Fastembed::TextEmbedding.new(
+  #     local_model_dir: "/path/to/model",
+  #     model_file: "model.onnx",
+  #     tokenizer_file: "tokenizer.json"
+  #   )
+  #
   class TextEmbedding
-
+    include BaseModel
+
+    attr_reader :dim
 
     # Initialize a text embedding model
     #
@@ -25,55 +34,84 @@ module Fastembed
     # @param threads [Integer, nil] Number of threads for ONNX Runtime
     # @param providers [Array<String>, nil] ONNX execution providers (e.g., ["CoreMLExecutionProvider"])
     # @param show_progress [Boolean] Whether to show download progress
+    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
+    # @param local_model_dir [String, nil] Load model from local directory instead of downloading
+    # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
+    # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
    def initialize(
       model_name: DEFAULT_MODEL,
       cache_dir: nil,
       threads: nil,
       providers: nil,
-      show_progress: true
+      show_progress: true,
+      quantization: nil,
+      local_model_dir: nil,
+      model_file: nil,
+      tokenizer_file: nil
     )
-
-
-
-
-
-
-
+      if local_model_dir
+        initialize_from_local(
+          local_model_dir: local_model_dir,
+          model_name: model_name,
+          threads: threads,
+          providers: providers,
+          quantization: quantization,
+          model_file: model_file,
+          tokenizer_file: tokenizer_file
+        )
+      else
+        initialize_model(
+          model_name: model_name,
+          cache_dir: cache_dir,
+          threads: threads,
+          providers: providers,
+          show_progress: show_progress,
+          quantization: quantization
+        )
+      end
 
-      # Resolve model info
-      @model_info = ModelManagement.resolve_model_info(model_name)
       @dim = @model_info.dim
-
-
-
-
+      @model = OnnxEmbeddingModel.new(
+        @model_info,
+        @model_dir,
+        threads: threads,
+        providers: providers,
+        model_file_override: model_file || quantized_model_file
+      )
     end
 
     # Generate embeddings for documents
     #
     # @param documents [Array<String>, String] Text document(s) to embed
     # @param batch_size [Integer] Number of documents to process at once
+    # @yield [Progress] Optional progress callback called after each batch
     # @return [Enumerator] Lazy enumerator yielding embedding vectors
     # @raise [ArgumentError] If documents is nil or contains nil values
     #
-    # @example
+    # @example Basic usage
     #   vectors = embedding.embed(["Hello", "World"]).to_a
     #   # => [[0.1, 0.2, ...], [0.3, 0.4, ...]]
-
-
-
-
+    #
+    # @example With progress callback
+    #   embedding.embed(documents, batch_size: 64) do |progress|
+    #     puts "#{progress.percent}% complete"
+    #   end.to_a
+    #
+    def embed(documents, batch_size: 256, &progress_callback)
+      documents = Validators.validate_documents!(documents)
       return Enumerator.new { |_| } if documents.empty?
 
-
-      documents.each_with_index do |doc, i|
-        raise ArgumentError, "document at index #{i} cannot be nil" if doc.nil?
-      end
+      total_batches = (documents.length.to_f / batch_size).ceil
 
       Enumerator.new do |yielder|
-        documents.each_slice(batch_size) do |batch|
+        documents.each_slice(batch_size).with_index(1) do |batch, batch_num|
           embeddings = @model.embed(batch)
           embeddings.each { |embedding| yielder << embedding }
+
+          if progress_callback
+            progress = Progress.new(current: batch_num, total: total_batches, batch_size: batch_size)
+            progress_callback.call(progress)
+          end
         end
       end
     end
@@ -100,11 +138,60 @@ module Fastembed
       embed(prefixed, batch_size: batch_size)
     end
 
-    #
+    # Generate embeddings asynchronously in a background thread
+    #
+    # Returns immediately with a Future object. The embedding computation
+    # runs in a background thread. Call `value` on the Future to get results.
+    #
+    # @param documents [Array<String>, String] Text document(s) to embed
+    # @param batch_size [Integer] Number of documents to process at once
+    # @return [Async::Future] Future that resolves to array of embedding vectors
+    #
+    # @example Basic async usage
+    #   future = embedding.embed_async(documents)
+    #   # ... do other work ...
+    #   vectors = future.value # blocks until complete
+    #
+    # @example Parallel embedding of multiple batches
+    #   futures = documents.each_slice(1000).map do |batch|
+    #     embedding.embed_async(batch)
+    #   end
+    #   all_vectors = futures.flat_map(&:value)
+    #
+    def embed_async(documents, batch_size: 256)
+      Async::Future.new do
+        embed(documents, batch_size: batch_size).to_a
+      end
+    end
+
+    # Generate query embeddings asynchronously
+    #
+    # @param queries [Array<String>, String] Query text(s) to embed
+    # @param batch_size [Integer] Number of queries to process at once
+    # @return [Async::Future] Future that resolves to array of embedding vectors
+    def query_embed_async(queries, batch_size: 256)
+      Async::Future.new do
+        query_embed(queries, batch_size: batch_size).to_a
+      end
+    end
+
+    # Generate passage embeddings asynchronously
+    #
+    # @param passages [Array<String>, String] Passage text(s) to embed
+    # @param batch_size [Integer] Number of passages to process at once
+    # @return [Async::Future] Future that resolves to array of embedding vectors
+    def passage_embed_async(passages, batch_size: 256)
+      Async::Future.new do
+        passage_embed(passages, batch_size: batch_size).to_a
+      end
+    end
+
+    # List all supported models (built-in and custom)
     #
     # @return [Array<Hash>] Array of model information hashes
     def self.list_supported_models
-      SUPPORTED_MODELS.
+      all_models = SUPPORTED_MODELS.merge(CustomModelRegistry.embedding_models)
+      all_models.values.map(&:to_h)
     end
 
     # Get information about a specific model
@@ -112,7 +199,53 @@ module Fastembed
     # @param model_name [String] Name of the model
     # @return [Hash, nil] Model information or nil if not found
     def self.get_model_info(model_name)
-      SUPPORTED_MODELS[model_name]
+      info = SUPPORTED_MODELS[model_name] || CustomModelRegistry.embedding_models[model_name]
+      info&.to_h
+    end
+
+    private
+
+    def resolve_model_info(model_name)
+      ModelManagement.resolve_model_info(model_name)
+    end
+
+    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
+      # Detect dimension from model output shape if possible
+      # For now, use a placeholder that will be updated after model load
+      ModelInfo.new(
+        model_name: model_name,
+        dim: detect_model_dimension(model_file) || 384,
+        description: 'Local model',
+        size_in_gb: 0,
+        sources: {},
+        model_file: model_file || 'model.onnx',
+        tokenizer_file: tokenizer_file || 'tokenizer.json'
+      )
+    end
+
+    # Detect embedding dimension from ONNX model output shape
+    #
+    # @param model_file [String, nil] Model filename to inspect
+    # @return [Integer, nil] Detected dimension or nil if detection fails
+    def detect_model_dimension(model_file)
+      model_path = File.join(@model_dir, model_file || 'model.onnx')
+      return nil unless File.exist?(model_path)
+
+      session = OnnxRuntime::InferenceSession.new(model_path)
+      # Look for output shape - usually [batch, seq_len, hidden_size] or [batch, hidden_size]
+      output = session.outputs.first
+      return nil unless output && output[:shape]
+
+      shape = output[:shape]
+      # Last dimension is usually the embedding dimension
+      dim = shape.last
+      dim.is_a?(Integer) && dim.positive? ? dim : nil
+    rescue OnnxRuntime::Error => e
+      warn "Warning: Could not detect model dimension: #{e.message}"
+      nil
+    rescue StandardError => e
+      warn "Warning: Unexpected error detecting model dimension: #{e.class} - #{e.message}"
+      nil
     end
   end
 end
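The two user-visible additions to TextEmbedding are the per-batch progress callback on `embed` and the `*_async` variants that return `Async::Future`s. A short sketch of both, built from the @example blocks in the diff (the document contents are illustrative):

```ruby
require 'fastembed'

embedding = Fastembed::TextEmbedding.new
documents = Array.new(1_000) { |i| "document number #{i}" }

# embed now yields a Progress object after each batch; the @example above
# shows Progress responding to #percent.
vectors = embedding.embed(documents, batch_size: 64) do |progress|
  puts "#{progress.percent}% complete"
end.to_a

# embed_async runs the same computation in a background thread and returns
# an Async::Future; #value blocks until the result is ready.
futures = documents.each_slice(500).map { |batch| embedding.embed_async(batch) }
all_vectors = futures.flat_map(&:value)
```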
data/lib/fastembed/validators.rb ADDED
@@ -0,0 +1,59 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Input validation helpers for embedding models
+  #
+  # Provides consistent validation and normalization of documents, queries,
+  # and other inputs across all model types.
+  #
+  # @api private
+  #
+  module Validators
+    class << self
+      # Validate and normalize document input
+      #
+      # Ensures documents are not nil, converts single strings to arrays,
+      # and validates that no individual document is nil.
+      #
+      # @param documents [Array<String>, String, nil] Documents to validate
+      # @return [Array<String>] Normalized array of documents
+      # @raise [ArgumentError] If documents is nil or contains nil values
+      def validate_documents!(documents)
+        raise ArgumentError, 'documents cannot be nil' if documents.nil?
+
+        documents = [documents] if documents.is_a?(String)
+
+        documents.each_with_index do |doc, i|
+          raise ArgumentError, "document at index #{i} cannot be nil" if doc.nil?
+        end
+
+        documents
+      end
+
+      # Validate query and documents for reranking
+      #
+      # @param query [String, nil] Query to validate
+      # @param documents [Array<String>, nil] Documents to validate
+      # @return [Array<String>] Validated documents array
+      # @raise [ArgumentError] If query or documents is nil, or documents contains nil
+      def validate_rerank_input!(query:, documents:)
+        raise ArgumentError, 'query cannot be nil' if query.nil?
+        raise ArgumentError, 'documents cannot be nil' if documents.nil?
+
+        documents.each_with_index do |doc, i|
+          raise ArgumentError, "document at index #{i} cannot be nil" if doc.nil?
+        end
+
+        documents
+      end
+
+      # Check if documents array is empty
+      #
+      # @param documents [Array<String>] Documents to check
+      # @return [Boolean] True if empty
+      def empty?(documents)
+        documents.empty?
+      end
+    end
+  end
+end
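Validators is marked `@api private`, so user code would not normally call it directly; the sketch below only illustrates the normalization behavior the module documents:

```ruby
# A bare string is wrapped into a one-element array.
Fastembed::Validators.validate_documents!('hello')   # => ["hello"]

# nil input and nil elements both raise ArgumentError.
Fastembed::Validators.validate_documents!(nil)
# => ArgumentError: documents cannot be nil

Fastembed::Validators.validate_rerank_input!(query: 'q', documents: ['a', nil])
# => ArgumentError: document at index 1 cannot be nil
```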
data/lib/fastembed/version.rb CHANGED
data/lib/fastembed.rb CHANGED
@@ -2,13 +2,54 @@
 
 require_relative 'fastembed/version'
 
+# Fastembed - Fast, lightweight text embeddings for Ruby
+#
+# A Ruby port of FastEmbed providing text embeddings using ONNX Runtime.
+# Supports dense embeddings, sparse embeddings (SPLADE), late interaction (ColBERT),
+# and cross-encoder reranking.
+#
+# @example Basic text embedding
+#   embedding = Fastembed::TextEmbedding.new
+#   vectors = embedding.embed(["Hello world", "Ruby is great"]).to_a
+#
+# @example Reranking documents
+#   reranker = Fastembed::TextCrossEncoder.new
+#   scores = reranker.rerank(query: "What is ML?", documents: docs)
+#
+# @example Sparse embeddings
+#   sparse = Fastembed::TextSparseEmbedding.new
+#   embeddings = sparse.embed(["Hello"]).to_a
+#
+# @example Async embedding
+#   future = embedding.embed_async(large_document_list)
+#   vectors = future.value # blocks until complete
+#
+# @see https://github.com/khasinski/fastembed-rb
+#
 module Fastembed
+  # Base error class for all Fastembed errors
   class Error < StandardError; end
+
+  # Raised when model download fails
   class DownloadError < Error; end
 end
 
+require_relative 'fastembed/pooling'
+require_relative 'fastembed/base_model_info'
 require_relative 'fastembed/model_info'
+require_relative 'fastembed/reranker_model_info'
+require_relative 'fastembed/sparse_model_info'
+require_relative 'fastembed/late_interaction_model_info'
+require_relative 'fastembed/custom_model_registry'
 require_relative 'fastembed/model_management'
-require_relative 'fastembed/
+require_relative 'fastembed/quantization'
+require_relative 'fastembed/progress'
+require_relative 'fastembed/async'
+require_relative 'fastembed/validators'
+require_relative 'fastembed/base_model'
 require_relative 'fastembed/onnx_embedding_model'
 require_relative 'fastembed/text_embedding'
+require_relative 'fastembed/text_cross_encoder'
+require_relative 'fastembed/sparse_embedding'
+require_relative 'fastembed/late_interaction_embedding'
+require_relative 'fastembed/image_embedding'