fastembed 1.0.0 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +1 -0
- data/.yardopts +6 -0
- data/BENCHMARKS.md +124 -1
- data/CHANGELOG.md +14 -0
- data/README.md +395 -74
- data/benchmark/compare_all.rb +167 -0
- data/benchmark/compare_python.py +60 -0
- data/benchmark/memory_profile.rb +70 -0
- data/benchmark/profile.rb +198 -0
- data/benchmark/reranker_benchmark.rb +158 -0
- data/exe/fastembed +6 -0
- data/fastembed.gemspec +3 -0
- data/lib/fastembed/async.rb +193 -0
- data/lib/fastembed/base_model.rb +247 -0
- data/lib/fastembed/base_model_info.rb +61 -0
- data/lib/fastembed/cli.rb +745 -0
- data/lib/fastembed/custom_model_registry.rb +255 -0
- data/lib/fastembed/image_embedding.rb +313 -0
- data/lib/fastembed/late_interaction_embedding.rb +260 -0
- data/lib/fastembed/late_interaction_model_info.rb +91 -0
- data/lib/fastembed/model_info.rb +59 -19
- data/lib/fastembed/model_management.rb +82 -23
- data/lib/fastembed/onnx_embedding_model.rb +25 -4
- data/lib/fastembed/pooling.rb +39 -3
- data/lib/fastembed/progress.rb +52 -0
- data/lib/fastembed/quantization.rb +75 -0
- data/lib/fastembed/reranker_model_info.rb +91 -0
- data/lib/fastembed/sparse_embedding.rb +261 -0
- data/lib/fastembed/sparse_model_info.rb +80 -0
- data/lib/fastembed/text_cross_encoder.rb +217 -0
- data/lib/fastembed/text_embedding.rb +161 -28
- data/lib/fastembed/validators.rb +59 -0
- data/lib/fastembed/version.rb +1 -1
- data/lib/fastembed.rb +42 -1
- data/plan.md +257 -0
- data/scripts/verify_models.rb +229 -0
- metadata +70 -3
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fastembed
|
|
4
|
+
# Represents a late interaction (ColBERT-style) embedding.
#
# Holds one vector per token instead of a single pooled vector, so that a
# query and a document can be compared token-by-token with MaxSim.
class LateInteractionEmbedding
  # @!attribute [r] embeddings
  #   @return [Array<Array<Float>>] Token-level embeddings, one row per token
  # @!attribute [r] token_count
  #   @return [Integer] Number of token embeddings held
  attr_reader :embeddings, :token_count

  # @param embeddings [Array<Array<Float>>] Token-level embeddings
  def initialize(embeddings)
    @embeddings = embeddings
    @token_count = embeddings.length
  end

  # Dimension of each token embedding.
  #
  # @return [Integer] Length of the first token vector, or 0 when there are no tokens
  def dim
    @embeddings.first&.length || 0
  end

  # Compute the MaxSim score against another late interaction embedding.
  #
  # This is the core ColBERT scoring mechanism: for every token vector on this
  # (query) side, take the highest dot product with any token vector of +other+
  # (the document side), then sum those maxima. When token vectors are
  # L2-normalised (as LateInteractionTextEmbedding produces them), each dot
  # product is a cosine similarity.
  #
  # @param other [LateInteractionEmbedding] Document embedding to score against
  # @return [Float] MaxSim relevance score; 0.0 if either side has no tokens
  # @raise [ArgumentError] If token vectors of the two embeddings differ in length
  def max_sim(other)
    return 0.0 if embeddings.empty? || other.embeddings.empty?

    embeddings.sum do |query_vec|
      other.embeddings.map { |doc_vec| dot_product(query_vec, doc_vec) }.max
    end
  end

  # @return [Array<Array<Float>>] The underlying token embeddings
  def to_a
    @embeddings
  end

  # @return [String] Short summary with token count and dimension
  def to_s
    "LateInteractionEmbedding(tokens=#{token_count}, dim=#{dim})"
  end

  # @return [String] Same summary as {#to_s}, avoiding a dump of every vector
  def inspect
    to_s
  end

  private

  # Dot product of two equal-length vectors.
  #
  # Array#zip pads a shorter second argument with nil and truncates a longer
  # one, so a mismatch would otherwise raise an obscure TypeError or silently
  # produce a wrong score. Fail loudly instead.
  #
  # @raise [ArgumentError] If the vectors differ in length
  def dot_product(a, b)
    unless a.length == b.length
      raise ArgumentError, "Dimension mismatch: #{a.length} vs #{b.length}"
    end

    a.zip(b).sum { |x, y| x * y }
  end
end
|
|
54
|
+
|
|
55
|
+
# Late interaction text embedding using ColBERT-style models
#
# Unlike standard embeddings that produce one vector per document,
# late interaction models produce one vector per token. This enables
# more fine-grained matching using MaxSim scoring.
#
# @example Basic usage
#   model = Fastembed::LateInteractionTextEmbedding.new
#   query_emb = model.query_embed("What is ML?").first
#   doc_emb = model.passage_embed("Machine learning is...").first
#   score = query_emb.max_sim(doc_emb)
#
class LateInteractionTextEmbedding
  include BaseModel

  # @!attribute [r] dim
  #   @return [Integer] Per-token embedding dimension taken from the model info
  attr_reader :dim

  # Initialize a late interaction embedding model
  #
  # @param model_name [String] Name of the model to use
  # @param cache_dir [String, nil] Custom cache directory for models
  # @param threads [Integer, nil] Number of threads for ONNX Runtime
  # @param providers [Array<String>, nil] ONNX execution providers
  # @param show_progress [Boolean] Whether to show download progress
  # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
  # @param local_model_dir [String, nil] Load model from local directory instead of downloading
  # @param model_file [String, nil] Override model file name (e.g., "model.onnx")
  # @param tokenizer_file [String, nil] Override tokenizer file name (e.g., "tokenizer.json")
  def initialize(
    model_name: DEFAULT_LATE_INTERACTION_MODEL,
    cache_dir: nil,
    threads: nil,
    providers: nil,
    show_progress: true,
    quantization: nil,
    local_model_dir: nil,
    model_file: nil,
    tokenizer_file: nil
  )
    if local_model_dir
      initialize_from_local(
        local_model_dir: local_model_dir,
        model_name: model_name,
        threads: threads,
        providers: providers,
        quantization: quantization,
        model_file: model_file,
        tokenizer_file: tokenizer_file
      )
    else
      initialize_model(
        model_name: model_name,
        cache_dir: cache_dir,
        threads: threads,
        providers: providers,
        show_progress: show_progress,
        quantization: quantization
      )
    end

    @dim = @model_info.dim
    setup_model_and_tokenizer(model_file_override: model_file || quantized_model_file)
  end

  # Generate late interaction embeddings for documents
  #
  # This method adds no query/document marker itself; use {#query_embed} or
  # {#passage_embed} for the "[Q] " / "[D] " prefixed variants.
  #
  # @param documents [Array<String>, String] Text document(s) to embed
  # @param batch_size [Integer] Number of documents to process at once (positive)
  # @yield [Progress] Optional progress callback called after each batch
  # @return [Enumerator] Lazy enumerator yielding LateInteractionEmbedding objects
  # @raise [ArgumentError] If batch_size is not a positive Integer
  def embed(documents, batch_size: 32, &progress_callback)
    documents = Validators.validate_documents!(documents)
    # Check eagerly. Otherwise a zero batch_size surfaces as a FloatDomainError
    # from the batch count, and a negative one only fails once the enumerator is iterated.
    unless batch_size.is_a?(Integer) && batch_size.positive?
      raise ArgumentError, "batch_size must be a positive Integer, got #{batch_size.inspect}"
    end
    return Enumerator.new { |_| } if documents.empty?

    # Integer ceiling division: number of slices each_slice will produce
    total_batches = (documents.length + batch_size - 1) / batch_size

    Enumerator.new do |yielder|
      documents.each_slice(batch_size).with_index(1) do |batch, batch_num|
        compute_embeddings(batch).each { |emb| yielder << emb }
        next unless progress_callback

        progress_callback.call(Progress.new(current: batch_num, total: total_batches, batch_size: batch_size))
      end
    end
  end

  # Generate late interaction embeddings for queries
  # Each query is prefixed with the "[Q] " marker for asymmetric retrieval.
  #
  # @param queries [Array<String>, String] Query text(s) to embed
  # @param batch_size [Integer] Number of queries to process at once
  # @yield [Progress] Optional progress callback called after each batch
  # @return [Enumerator] Lazy enumerator yielding LateInteractionEmbedding objects
  def query_embed(queries, batch_size: 32, &progress_callback)
    # Validate the raw input before prefixing. Once the marker is prepended,
    # every element is a non-empty string (nil interpolates to ""), which
    # could let entries the validator would reject slip through.
    queries = Validators.validate_documents!(queries)
    # NOTE(review): the marker is added as literal text. Reference ColBERT
    # implementations insert a dedicated marker token instead; confirm the
    # tokenizer maps "[Q]" to the token this model was trained with.
    embed(queries.map { |q| "[Q] #{q}" }, batch_size: batch_size, &progress_callback)
  end

  # Generate late interaction embeddings for passages/documents
  # Each passage is prefixed with the "[D] " marker.
  #
  # @param passages [Array<String>, String] Passage text(s) to embed
  # @param batch_size [Integer] Number of passages to process at once
  # @yield [Progress] Optional progress callback called after each batch
  # @return [Enumerator] Lazy enumerator yielding LateInteractionEmbedding objects
  def passage_embed(passages, batch_size: 32, &progress_callback)
    # Validate before prefixing, for the same reason as in #query_embed
    passages = Validators.validate_documents!(passages)
    # NOTE(review): literal "[D]" marker; see the note in #query_embed
    embed(passages.map { |p| "[D] #{p}" }, batch_size: batch_size, &progress_callback)
  end

  # Generate embeddings asynchronously
  #
  # @param documents [Array<String>, String] Text document(s) to embed
  # @param batch_size [Integer] Number of documents to process at once
  # @return [Async::Future] Future that resolves to array of LateInteractionEmbedding objects
  def embed_async(documents, batch_size: 32)
    Async::Future.new { embed(documents, batch_size: batch_size).to_a }
  end

  # Generate query embeddings asynchronously
  #
  # @param queries [Array<String>, String] Query text(s) to embed
  # @param batch_size [Integer] Number of queries to process at once
  # @return [Async::Future] Future that resolves to array of LateInteractionEmbedding objects
  def query_embed_async(queries, batch_size: 32)
    Async::Future.new { query_embed(queries, batch_size: batch_size).to_a }
  end

  # Generate passage embeddings asynchronously
  #
  # @param passages [Array<String>, String] Passage text(s) to embed
  # @param batch_size [Integer] Number of passages to process at once
  # @return [Async::Future] Future that resolves to array of LateInteractionEmbedding objects
  def passage_embed_async(passages, batch_size: 32)
    Async::Future.new { passage_embed(passages, batch_size: batch_size).to_a }
  end

  # List all supported late interaction models
  #
  # @return [Array<Hash>] Array of model information hashes
  def self.list_supported_models
    SUPPORTED_LATE_INTERACTION_MODELS.values.map(&:to_h)
  end

  private

  # Look up model info: built-in registry first, then the custom registry.
  #
  # @raise [Error] If the model is in neither registry
  def resolve_model_info(model_name)
    info = SUPPORTED_LATE_INTERACTION_MODELS[model_name]
    return info if info

    info = CustomModelRegistry.late_interaction_models[model_name]
    return info if info

    raise Error, "Unknown late interaction model: #{model_name}"
  end

  # Build model info for a model loaded from a local directory.
  # NOTE(review): dim is fixed at 128 here and becomes @dim, so a local model
  # with a different per-token width will report the wrong dim. Confirm
  # whether callers need an override.
  def create_local_model_info(model_name:, model_file:, tokenizer_file:)
    LateInteractionModelInfo.new(
      model_name: model_name,
      description: 'Local late interaction model',
      size_in_gb: 0,
      sources: {},
      model_file: model_file || 'model.onnx',
      tokenizer_file: tokenizer_file || 'tokenizer.json',
      dim: 128 # Default ColBERT dimension
    )
  end

  # Run the model on one batch and wrap each text's token vectors.
  # Positions whose attention-mask value is not 1 (padding) are dropped, and
  # each kept token vector is L2-normalised.
  def compute_embeddings(texts)
    prepared = tokenize_and_prepare(texts)
    outputs = @session.run(nil, prepared[:inputs])
    token_embeddings = extract_token_embeddings(outputs)
    attention_mask = prepared[:attention_mask]

    texts.each_index.map do |i|
      kept = token_embeddings[i].each_with_index.select { |_emb, j| attention_mask[i][j] == 1 }
      LateInteractionEmbedding.new(kept.map { |emb, _j| normalize_vector(emb) })
    end
  end

  # Pick the per-token output tensor from the session result, preferring
  # named outputs when the result is a Hash.
  def extract_token_embeddings(outputs)
    if outputs.is_a?(Hash)
      outputs['last_hidden_state'] || outputs['token_embeddings'] || outputs.values.first
    else
      outputs.first
    end
  end

  # L2-normalise a vector; a zero vector is returned unchanged to avoid dividing by zero.
  def normalize_vector(vec)
    norm = Math.sqrt(vec.sum { |x| x * x })
    return vec if norm.zero?

    vec.map { |x| x / norm }
  end
end
|
|
260
|
+
end
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fastembed
|
|
4
|
+
# Model information for late interaction models (ColBERT, etc.)
#
# Late interaction models emit one embedding per token rather than a single
# document vector; relevance is then computed with MaxSim over those vectors.
#
# @example Access late interaction model info
#   info = Fastembed::SUPPORTED_LATE_INTERACTION_MODELS['colbert-ir/colbertv2.0']
#   info.dim # => 128
#
class LateInteractionModelInfo
  include BaseModelInfo

  # @!attribute [r] dim
  #   @return [Integer] Output embedding dimension per token
  attr_reader :dim

  # Build a late interaction model descriptor.
  #
  # @param model_name [String] Full model identifier
  # @param dim [Integer] Output embedding dimension per token
  # @param description [String] Human-readable description
  # @param size_in_gb [Float] Model size in GB
  # @param sources [Hash] Source repositories
  # @param model_file [String] Path to ONNX model file (defaults to 'onnx/model.onnx')
  # @param tokenizer_file [String] Path to tokenizer file (defaults to 'tokenizer.json')
  # @param max_length [Integer] Maximum sequence length (defaults to 512)
  def initialize(
    model_name:,
    dim:,
    description:,
    size_in_gb:,
    sources:,
    model_file: 'onnx/model.onnx',
    tokenizer_file: 'tokenizer.json',
    max_length: 512
  )
    shared = {
      model_name: model_name,
      description: description,
      size_in_gb: size_in_gb,
      sources: sources,
      model_file: model_file,
      tokenizer_file: tokenizer_file,
      max_length: max_length
    }
    initialize_base(**shared)
    @dim = dim
  end

  # Convert to a plain Hash, one entry per public attribute, in a fixed key order.
  #
  # @return [Hash] Model info as a hash
  def to_h
    %i[model_name dim description size_in_gb sources model_file tokenizer_file max_length]
      .to_h { |attribute| [attribute, public_send(attribute)] }
  end
end
|
|
67
|
+
|
|
68
|
+
# Registry of supported late interaction models, keyed by model name.
#
# Entries are declared as [name, attributes] pairs. The name is used both as
# the hash key and as the model_name passed to LateInteractionModelInfo.
SUPPORTED_LATE_INTERACTION_MODELS = [
  [
    'colbert-ir/colbertv2.0',
    {
      dim: 128,
      description: 'ColBERTv2 for late interaction retrieval',
      size_in_gb: 0.44,
      sources: { hf: 'colbert-ir/colbertv2.0' },
      model_file: 'model.onnx',
      max_length: 512
    }
  ],
  [
    'jinaai/jina-colbert-v1-en',
    {
      dim: 768,
      description: 'Jina ColBERT v1 for English with 8192 context',
      size_in_gb: 0.43,
      sources: { hf: 'onnx-models/jina-colbert-v1-en-onnx' },
      model_file: 'model.onnx',
      max_length: 8192
    }
  ]
].to_h { |name, attributes| [name, LateInteractionModelInfo.new(model_name: name, **attributes)] }.freeze

# Model used when LateInteractionTextEmbedding.new is called without model_name
DEFAULT_LATE_INTERACTION_MODEL = 'colbert-ir/colbertv2.0'
|
|
91
|
+
end
|
data/lib/fastembed/model_info.rb
CHANGED
|
@@ -1,11 +1,41 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
3
|
module Fastembed
|
|
4
|
-
# Model information
|
|
4
|
+
# Model information for dense embedding models
|
|
5
|
+
#
|
|
6
|
+
# Stores metadata and configuration for ONNX embedding models including
|
|
7
|
+
# output dimensions, pooling strategy, and normalization settings.
|
|
8
|
+
#
|
|
9
|
+
# @example Access model info
|
|
10
|
+
# info = Fastembed::SUPPORTED_MODELS['BAAI/bge-small-en-v1.5']
|
|
11
|
+
# info.dim # => 384
|
|
12
|
+
# info.pooling # => :mean
|
|
13
|
+
# info.normalize # => true
|
|
14
|
+
#
|
|
5
15
|
class ModelInfo
|
|
6
|
-
|
|
7
|
-
:tokenizer_file, :sources, :pooling, :normalize
|
|
16
|
+
include BaseModelInfo
|
|
8
17
|
|
|
18
|
+
# @!attribute [r] dim
|
|
19
|
+
# @return [Integer] Output embedding dimension
|
|
20
|
+
# @!attribute [r] pooling
|
|
21
|
+
# @return [Symbol] Pooling strategy (:mean or :cls)
|
|
22
|
+
# @!attribute [r] normalize
|
|
23
|
+
# @return [Boolean] Whether to L2 normalize output embeddings
|
|
24
|
+
attr_reader :dim, :pooling, :normalize
|
|
25
|
+
|
|
26
|
+
# Create a new ModelInfo instance
|
|
27
|
+
#
|
|
28
|
+
# @param model_name [String] Full model identifier
|
|
29
|
+
# @param dim [Integer] Output embedding dimension
|
|
30
|
+
# @param description [String] Human-readable description
|
|
31
|
+
# @param size_in_gb [Float] Model size in GB
|
|
32
|
+
# @param sources [Hash] Source repositories
|
|
33
|
+
# @param model_file [String] Path to ONNX model file
|
|
34
|
+
# @param tokenizer_file [String] Path to tokenizer file
|
|
35
|
+
# @param pooling [Symbol] Pooling strategy (:mean or :cls)
|
|
36
|
+
# @param normalize [Boolean] Whether to L2 normalize outputs
|
|
37
|
+
# @param max_length [Integer] Maximum sequence length
|
|
38
|
+
# @raise [ArgumentError] If pooling strategy is invalid
|
|
9
39
|
def initialize(
|
|
10
40
|
model_name:,
|
|
11
41
|
dim:,
|
|
@@ -15,23 +45,28 @@ module Fastembed
|
|
|
15
45
|
model_file: 'model.onnx',
|
|
16
46
|
tokenizer_file: 'tokenizer.json',
|
|
17
47
|
pooling: :mean,
|
|
18
|
-
normalize: true
|
|
48
|
+
normalize: true,
|
|
49
|
+
max_length: BaseModelInfo::DEFAULT_MAX_LENGTH
|
|
19
50
|
)
|
|
20
|
-
|
|
51
|
+
unless Pooling.valid?(pooling)
|
|
52
|
+
valid = Pooling::VALID_STRATEGIES.join(', ')
|
|
53
|
+
raise ArgumentError, "Invalid pooling strategy: #{pooling}. Valid strategies: #{valid}"
|
|
54
|
+
end
|
|
55
|
+
|
|
56
|
+
initialize_base(
|
|
57
|
+
model_name: model_name,
|
|
58
|
+
description: description,
|
|
59
|
+
size_in_gb: size_in_gb,
|
|
60
|
+
sources: sources,
|
|
61
|
+
model_file: model_file,
|
|
62
|
+
tokenizer_file: tokenizer_file,
|
|
63
|
+
max_length: max_length
|
|
64
|
+
)
|
|
21
65
|
@dim = dim
|
|
22
|
-
@description = description
|
|
23
|
-
@size_in_gb = size_in_gb
|
|
24
|
-
@sources = sources
|
|
25
|
-
@model_file = model_file
|
|
26
|
-
@tokenizer_file = tokenizer_file
|
|
27
66
|
@pooling = pooling
|
|
28
67
|
@normalize = normalize
|
|
29
68
|
end
|
|
30
69
|
|
|
31
|
-
def hf_repo
|
|
32
|
-
sources[:hf]
|
|
33
|
-
end
|
|
34
|
-
|
|
35
70
|
def to_h
|
|
36
71
|
{
|
|
37
72
|
model_name: model_name,
|
|
@@ -42,7 +77,8 @@ module Fastembed
|
|
|
42
77
|
model_file: model_file,
|
|
43
78
|
tokenizer_file: tokenizer_file,
|
|
44
79
|
pooling: pooling,
|
|
45
|
-
normalize: normalize
|
|
80
|
+
normalize: normalize,
|
|
81
|
+
max_length: max_length
|
|
46
82
|
}
|
|
47
83
|
end
|
|
48
84
|
end
|
|
@@ -103,7 +139,8 @@ module Fastembed
|
|
|
103
139
|
description: 'Long context (8192 tokens) English embedding model',
|
|
104
140
|
size_in_gb: 0.52,
|
|
105
141
|
sources: { hf: 'nomic-ai/nomic-embed-text-v1' },
|
|
106
|
-
model_file: 'onnx/model.onnx'
|
|
142
|
+
model_file: 'onnx/model.onnx',
|
|
143
|
+
max_length: 8192
|
|
107
144
|
),
|
|
108
145
|
'nomic-ai/nomic-embed-text-v1.5' => ModelInfo.new(
|
|
109
146
|
model_name: 'nomic-ai/nomic-embed-text-v1.5',
|
|
@@ -111,7 +148,8 @@ module Fastembed
|
|
|
111
148
|
description: 'Improved long context embedding with Matryoshka support',
|
|
112
149
|
size_in_gb: 0.52,
|
|
113
150
|
sources: { hf: 'nomic-ai/nomic-embed-text-v1.5' },
|
|
114
|
-
model_file: 'onnx/model.onnx'
|
|
151
|
+
model_file: 'onnx/model.onnx',
|
|
152
|
+
max_length: 8192
|
|
115
153
|
),
|
|
116
154
|
'jinaai/jina-embeddings-v2-small-en' => ModelInfo.new(
|
|
117
155
|
model_name: 'jinaai/jina-embeddings-v2-small-en',
|
|
@@ -119,7 +157,8 @@ module Fastembed
|
|
|
119
157
|
description: 'Small English embedding with 8192 token context',
|
|
120
158
|
size_in_gb: 0.06,
|
|
121
159
|
sources: { hf: 'Xenova/jina-embeddings-v2-small-en' },
|
|
122
|
-
model_file: 'onnx/model.onnx'
|
|
160
|
+
model_file: 'onnx/model.onnx',
|
|
161
|
+
max_length: 8192
|
|
123
162
|
),
|
|
124
163
|
'jinaai/jina-embeddings-v2-base-en' => ModelInfo.new(
|
|
125
164
|
model_name: 'jinaai/jina-embeddings-v2-base-en',
|
|
@@ -127,7 +166,8 @@ module Fastembed
|
|
|
127
166
|
description: 'Base English embedding with 8192 token context',
|
|
128
167
|
size_in_gb: 0.52,
|
|
129
168
|
sources: { hf: 'Xenova/jina-embeddings-v2-base-en' },
|
|
130
|
-
model_file: 'onnx/model.onnx'
|
|
169
|
+
model_file: 'onnx/model.onnx',
|
|
170
|
+
max_length: 8192
|
|
131
171
|
),
|
|
132
172
|
'sentence-transformers/paraphrase-MiniLM-L6-v2' => ModelInfo.new(
|
|
133
173
|
model_name: 'sentence-transformers/paraphrase-MiniLM-L6-v2',
|
|
@@ -6,9 +6,24 @@ require 'json'
|
|
|
6
6
|
require 'fileutils'
|
|
7
7
|
|
|
8
8
|
module Fastembed
|
|
9
|
-
# Handles model downloading and caching
|
|
9
|
+
# Handles model downloading and caching from HuggingFace
|
|
10
|
+
#
|
|
11
|
+
# Downloads ONNX models and tokenizer files from HuggingFace repositories,
|
|
12
|
+
# caching them locally for subsequent use. Supports custom cache directories
|
|
13
|
+
# via environment variables.
|
|
14
|
+
#
|
|
15
|
+
# @example Check cache location
|
|
16
|
+
# Fastembed::ModelManagement.cache_dir
|
|
17
|
+
# # => "/home/user/.cache/fastembed"
|
|
18
|
+
#
|
|
19
|
+
# @example Use custom cache directory
|
|
20
|
+
# Fastembed::ModelManagement.cache_dir = "/custom/path"
|
|
21
|
+
#
|
|
10
22
|
module ModelManagement
|
|
23
|
+
# Base URL for HuggingFace API
|
|
11
24
|
HF_API_BASE = 'https://huggingface.co'
|
|
25
|
+
|
|
26
|
+
# Files required for model operation (in addition to model.onnx and tokenizer.json)
|
|
12
27
|
REQUIRED_FILES = %w[
|
|
13
28
|
config.json
|
|
14
29
|
tokenizer.json
|
|
@@ -18,7 +33,13 @@ module Fastembed
|
|
|
18
33
|
|
|
19
34
|
class << self
|
|
20
35
|
# Returns the cache directory for storing models
|
|
21
|
-
#
|
|
36
|
+
#
|
|
37
|
+
# Priority order:
|
|
38
|
+
# 1. FASTEMBED_CACHE_PATH environment variable
|
|
39
|
+
# 2. XDG_CACHE_HOME environment variable
|
|
40
|
+
# 3. ~/.cache (fallback)
|
|
41
|
+
#
|
|
42
|
+
# @return [String] Absolute path to cache directory
|
|
22
43
|
def cache_dir
|
|
23
44
|
@cache_dir ||= begin
|
|
24
45
|
base = ENV['FASTEMBED_CACHE_PATH'] ||
|
|
@@ -29,11 +50,22 @@ module Fastembed
|
|
|
29
50
|
end
|
|
30
51
|
|
|
31
52
|
# Set a custom cache directory
|
|
53
|
+
# @!attribute [w] cache_dir
|
|
54
|
+
# @return [String] Path to use as cache directory
|
|
32
55
|
attr_writer :cache_dir
|
|
33
56
|
|
|
34
57
|
# Returns the path to a cached model, downloading if necessary
|
|
35
|
-
|
|
36
|
-
|
|
58
|
+
#
|
|
59
|
+
# Downloads the model from HuggingFace if not already cached.
|
|
60
|
+
# The model directory will contain the ONNX model file and tokenizer.
|
|
61
|
+
#
|
|
62
|
+
# @param model_name [String] Name of the model (e.g., "BAAI/bge-small-en-v1.5")
|
|
63
|
+
# @param model_info [BaseModelInfo, nil] Optional pre-resolved model info
|
|
64
|
+
# @param show_progress [Boolean] Whether to print download progress
|
|
65
|
+
# @return [String] Absolute path to the model directory
|
|
66
|
+
# @raise [DownloadError] If the download fails
|
|
67
|
+
def retrieve_model(model_name, model_info: nil, show_progress: true)
|
|
68
|
+
model_info ||= resolve_model_info(model_name)
|
|
37
69
|
model_dir = model_directory(model_info)
|
|
38
70
|
|
|
39
71
|
# Check if model is already cached
|
|
@@ -45,6 +77,10 @@ module Fastembed
|
|
|
45
77
|
end
|
|
46
78
|
|
|
47
79
|
# Check if a model exists in cache
|
|
80
|
+
#
|
|
81
|
+
# @param model_dir [String] Path to model directory
|
|
82
|
+
# @param model_info [BaseModelInfo] Model info with required file paths
|
|
83
|
+
# @return [Boolean] True if model files exist
|
|
48
84
|
def model_cached?(model_dir, model_info)
|
|
49
85
|
return false unless Dir.exist?(model_dir)
|
|
50
86
|
|
|
@@ -56,21 +92,33 @@ module Fastembed
|
|
|
56
92
|
end
|
|
57
93
|
|
|
58
94
|
# Get the directory path for a model
#
# Every "/" in the model name becomes "--", so an "org/name" identifier maps
# to a single directory directly under <cache_dir>/models.
#
# @param model_info [BaseModelInfo] Model info
# @return [String] Path where model should be stored
def model_directory(model_info)
  flattened = model_info.model_name.split('/', -1).join('--')
  File.join(cache_dir, 'models', flattened)
end
|
|
64
103
|
|
|
65
|
-
# Resolve model name to ModelInfo from registry
#
# The built-in SUPPORTED_MODELS registry is consulted first, then
# CustomModelRegistry.embedding_models; a built-in entry therefore wins over
# a custom one registered under the same name.
#
# @param model_name [String] Model name to look up
# @return [ModelInfo] The model information
# @raise [ArgumentError] If model is not found in any registry
def resolve_model_info(model_name)
  SUPPORTED_MODELS[model_name] ||
    CustomModelRegistry.embedding_models[model_name] ||
    raise(ArgumentError,
          "Unknown model: #{model_name}. Use TextEmbedding.list_supported_models to see available models.")
end
|
|
75
123
|
|
|
76
124
|
private
|
|
@@ -84,6 +132,11 @@ module Fastembed
|
|
|
84
132
|
# Download model file
|
|
85
133
|
download_file(repo_id, model_info.model_file, model_dir, show_progress: show_progress)
|
|
86
134
|
|
|
135
|
+
# Some large models store weights in a separate .onnx_data file
|
|
136
|
+
# Try to download it if it exists (not required)
|
|
137
|
+
data_file = "#{model_info.model_file}_data"
|
|
138
|
+
download_file(repo_id, data_file, model_dir, show_progress: show_progress, required: false)
|
|
139
|
+
|
|
87
140
|
# Download tokenizer and config files
|
|
88
141
|
files_to_download = REQUIRED_FILES + [model_info.tokenizer_file]
|
|
89
142
|
files_to_download.uniq.each do |file|
|
|
@@ -129,23 +182,29 @@ module Fastembed
|
|
|
129
182
|
raise DownloadError, "Invalid URL scheme: #{url}"
|
|
130
183
|
end
|
|
131
184
|
|
|
132
|
-
|
|
185
|
+
# Use longer timeout for large files
|
|
186
|
+
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == 'https', read_timeout: 600,
|
|
133
187
|
open_timeout: 30) do |http|
|
|
134
188
|
request = Net::HTTP::Get.new(uri)
|
|
135
189
|
request['User-Agent'] = "fastembed-ruby/#{VERSION}"
|
|
136
190
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
191
|
+
http.request(request) do |response|
|
|
192
|
+
case response
|
|
193
|
+
when Net::HTTPSuccess
|
|
194
|
+
# Stream to file to handle large files without loading into memory
|
|
195
|
+
File.open(local_path, 'wb') do |file|
|
|
196
|
+
response.read_body do |chunk|
|
|
197
|
+
file.write(chunk)
|
|
198
|
+
end
|
|
199
|
+
end
|
|
200
|
+
when Net::HTTPRedirection
|
|
201
|
+
new_url = response['location']
|
|
202
|
+
# Handle relative redirects
|
|
203
|
+
new_url = "#{uri.scheme}://#{uri.host}#{new_url}" if new_url.start_with?('/')
|
|
204
|
+
download_with_redirect(new_url, local_path, show_progress: show_progress, max_redirects: max_redirects - 1)
|
|
205
|
+
else
|
|
206
|
+
raise DownloadError, "HTTP #{response.code}: #{response.message}"
|
|
207
|
+
end
|
|
149
208
|
end
|
|
150
209
|
end
|
|
151
210
|
end
|