fastembed 1.0.0 → 1.1.0
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/.rubocop.yml +1 -0
- data/.yardopts +6 -0
- data/BENCHMARKS.md +124 -1
- data/CHANGELOG.md +14 -0
- data/README.md +395 -74
- data/benchmark/compare_all.rb +167 -0
- data/benchmark/compare_python.py +60 -0
- data/benchmark/memory_profile.rb +70 -0
- data/benchmark/profile.rb +198 -0
- data/benchmark/reranker_benchmark.rb +158 -0
- data/exe/fastembed +6 -0
- data/fastembed.gemspec +3 -0
- data/lib/fastembed/async.rb +193 -0
- data/lib/fastembed/base_model.rb +247 -0
- data/lib/fastembed/base_model_info.rb +61 -0
- data/lib/fastembed/cli.rb +745 -0
- data/lib/fastembed/custom_model_registry.rb +255 -0
- data/lib/fastembed/image_embedding.rb +313 -0
- data/lib/fastembed/late_interaction_embedding.rb +260 -0
- data/lib/fastembed/late_interaction_model_info.rb +91 -0
- data/lib/fastembed/model_info.rb +59 -19
- data/lib/fastembed/model_management.rb +82 -23
- data/lib/fastembed/onnx_embedding_model.rb +25 -4
- data/lib/fastembed/pooling.rb +39 -3
- data/lib/fastembed/progress.rb +52 -0
- data/lib/fastembed/quantization.rb +75 -0
- data/lib/fastembed/reranker_model_info.rb +91 -0
- data/lib/fastembed/sparse_embedding.rb +261 -0
- data/lib/fastembed/sparse_model_info.rb +80 -0
- data/lib/fastembed/text_cross_encoder.rb +217 -0
- data/lib/fastembed/text_embedding.rb +161 -28
- data/lib/fastembed/validators.rb +59 -0
- data/lib/fastembed/version.rb +1 -1
- data/lib/fastembed.rb +42 -1
- data/plan.md +257 -0
- data/scripts/verify_models.rb +229 -0
- metadata +70 -3

data/lib/fastembed/async.rb
@@ -0,0 +1,193 @@
# frozen_string_literal: true

module Fastembed
  # Async support for embedding operations
  #
  # Provides Future-like objects for running embeddings in background threads.
  # Useful for parallelizing embedding generation across multiple documents.
  #
  # @example Async embedding
  #   embedding = Fastembed::TextEmbedding.new
  #   future = embedding.embed_async(documents)
  #   # ... do other work ...
  #   vectors = future.value # blocks until complete
  #
  # @example Multiple concurrent embeddings
  #   futures = documents.each_slice(100).map do |batch|
  #     embedding.embed_async(batch)
  #   end
  #   results = futures.flat_map(&:value)
  #
  module Async
    # A Future representing an async embedding operation
    #
    # Wraps a background thread that performs embedding, providing
    # methods to check completion status and retrieve results.
    #
    class Future
      # @return [Thread] The background thread
      attr_reader :thread

      # Create a new Future
      #
      # @yield Block to execute in background thread
      # @return [Future]
      def initialize(&block)
        @result = nil
        @error = nil
        @completed = false
        @mutex = Mutex.new
        @condition = ConditionVariable.new

        @thread = Thread.new do
          result = block.call
          @mutex.synchronize do
            @result = result
            @completed = true
            @condition.broadcast
          end
        rescue StandardError => e
          @mutex.synchronize do
            @error = e
            @completed = true
            @condition.broadcast
          end
        end
      end

      # Check if the operation is complete
      #
      # @return [Boolean] True if complete (success or failure)
      def complete?
        @mutex.synchronize { @completed }
      end

      alias completed? complete?

      # Check if the operation is still running
      #
      # @return [Boolean] True if still running
      def pending?
        !complete?
      end

      # Check if the operation completed successfully
      #
      # @return [Boolean] True if completed without error
      def success?
        @mutex.synchronize { @completed && @error.nil? }
      end

      # Check if the operation failed
      #
      # @return [Boolean] True if completed with error
      def failure?
        @mutex.synchronize { @completed && !@error.nil? }
      end

      # Get the result, blocking until complete
      #
      # @param timeout [Numeric, nil] Maximum seconds to wait (nil = forever)
      # @return [Object] The result of the async operation
      # @raise [StandardError] If the operation raised an error
      # @raise [Timeout::Error] If timeout expires before completion
      def value(timeout: nil)
        wait(timeout: timeout)
        raise @error if @error

        @result
      end

      alias result value

      # Wait for completion without retrieving the result
      #
      # @param timeout [Numeric, nil] Maximum seconds to wait (nil = forever)
      # @return [Boolean] True if completed, false if timed out
      def wait(timeout: nil)
        @mutex.synchronize do
          return true if @completed

          if timeout
            deadline = Time.now + timeout
            until @completed
              remaining = deadline - Time.now
              break if remaining <= 0

              @condition.wait(@mutex, remaining)
            end
          else
            @condition.wait(@mutex) until @completed
          end

          @completed
        end
      end

      # Get the error if the operation failed
      #
      # @return [StandardError, nil] The error, or nil if successful/pending
      def error
        @mutex.synchronize { @error }
      end

      # Apply a transformation to the result
      #
      # @yield [result] Block to transform the result
      # @return [Future] A new Future with the transformed result
      def then(&block)
        Future.new do
          block.call(value)
        end
      end

      # Handle errors
      #
      # @yield [error] Block to handle errors
      # @return [Future] A new Future that handles errors
      def rescue(&block)
        Future.new do
          value
        rescue StandardError => e
          block.call(e)
        end
      end
    end

    # Run multiple futures concurrently and wait for all to complete
    #
    # @param futures [Array<Future>] Futures to wait for
    # @param timeout [Numeric, nil] Maximum seconds to wait
    # @return [Array] Results from all futures
    # @raise [StandardError] If any future raised an error
    def self.all(futures, timeout: nil)
      futures.each { |f| f.wait(timeout: timeout) }
      futures.map(&:value)
    end

    # Run multiple futures concurrently and return first completed
    #
    # @param futures [Array<Future>] Futures to race
    # @param timeout [Numeric, nil] Maximum seconds to wait
    # @return [Object] Result from first completed future
    # @raise [Timeout::Error] If timeout expires before any future completes
    def self.race(futures, timeout: nil)
      raise ArgumentError, 'No futures provided' if futures.empty?

      deadline = timeout ? Time.now + timeout : nil
      sleep_time = 0.001 # Start with 1ms

      loop do
        futures.each do |future|
          return future.value if future.complete?
        end

        raise Timeout::Error, 'No future completed within timeout' if deadline && Time.now >= deadline

        sleep sleep_time
        # Exponential backoff up to 10ms to reduce CPU usage for long waits
        sleep_time = [sleep_time * 1.5, 0.01].min
      end
    end
  end
end
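
The new Async module composes with the `TextEmbedding#embed_async` call shown in the module's own examples. Below is a minimal usage sketch, assuming `embed_async` returns an `Async::Future` and that each batch result responds to `length`:

```ruby
require 'fastembed'

documents = Array.new(250) { |i| "document number #{i}" }
embedding = Fastembed::TextEmbedding.new

# One Future per batch: #then transforms the batch result, #rescue supplies a fallback.
futures = documents.each_slice(100).map do |batch|
  embedding.embed_async(batch)
           .then { |vectors| vectors.length }
           .rescue { |error| warn("batch failed: #{error.message}"); 0 }
end

# Async.all blocks until every Future resolves and returns their values in order.
counts = Fastembed::Async.all(futures)
puts "embedded #{counts.sum} documents across #{counts.length} batches"
```

`Async.race` follows the same pattern but returns the first completed result, polling with an exponential backoff that caps at 10 ms.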

data/lib/fastembed/base_model.rb
@@ -0,0 +1,247 @@
# frozen_string_literal: true

require 'onnxruntime'
require 'tokenizers'

module Fastembed
  # Shared functionality for model classes
  #
  # This module provides common initialization and utility methods used by
  # all model types (TextEmbedding, TextCrossEncoder, TextSparseEmbedding, etc.).
  # It handles model downloading, ONNX session creation, and tokenizer loading.
  #
  # @abstract Include in model classes and call {#initialize_model}
  #
  module BaseModel
    # @!attribute [r] model_name
    #   @return [String] Name of the loaded model
    # @!attribute [r] model_info
    #   @return [BaseModelInfo] Model metadata and configuration
    # @!attribute [r] quantization
    #   @return [Symbol] Current quantization type
    attr_reader :model_name, :model_info, :quantization

    private

    # Common initialization logic for all model types
    # @param model_name [String] Name of the model
    # @param cache_dir [String, nil] Custom cache directory
    # @param threads [Integer, nil] Number of threads for ONNX Runtime
    # @param providers [Array<String>, nil] ONNX execution providers
    # @param show_progress [Boolean] Whether to show download progress
    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
    def initialize_model(model_name:, cache_dir:, threads:, providers:, show_progress:, quantization: nil)
      @model_name = model_name
      @threads = threads
      @providers = providers
      @quantization = quantization || Quantization::DEFAULT

      validate_quantization!

      ModelManagement.cache_dir = cache_dir if cache_dir

      @model_info = resolve_model_info(model_name)
      @model_dir = retrieve_model(model_name, show_progress: show_progress)
    end

    # Validate that the quantization type is supported
    # @raise [ArgumentError] If quantization type is invalid
    def validate_quantization!
      return if Quantization.valid?(@quantization)

      valid_types = Quantization::TYPES.keys.join(', ')
      raise ArgumentError, "Invalid quantization type: #{@quantization}. Valid types: #{valid_types}"
    end

    # Get the model file path, accounting for quantization
    # @return [String] Path to quantized model file (or base if fp32)
    def quantized_model_file
      Quantization.model_file(@model_info.model_file, @quantization)
    end

    # Override in subclasses to resolve from the appropriate registry
    #
    # @param _model_name [String] Name of the model
    # @return [BaseModelInfo] Model information object
    # @raise [NotImplementedError] If not overridden in subclass
    # @abstract
    def resolve_model_info(_model_name)
      raise NotImplementedError, 'Subclasses must implement resolve_model_info'
    end

    # Download or retrieve cached model
    #
    # @param model_name [String] Name of the model
    # @param show_progress [Boolean] Whether to show download progress
    # @return [String] Path to model directory
    def retrieve_model(model_name, show_progress:)
      ModelManagement.retrieve_model(
        model_name,
        model_info: @model_info,
        show_progress: show_progress
      )
    end

    # Build ONNX session options hash
    # @return [Hash] Options for OnnxRuntime::InferenceSession
    def build_session_options
      options = {}
      options[:inter_op_num_threads] = @threads if @threads
      options[:intra_op_num_threads] = @threads if @threads
      options
    end

    # Load an ONNX inference session
    #
    # @param model_path [String] Path to ONNX model file
    # @param providers [Array<String>, nil] Execution providers
    # @return [OnnxRuntime::InferenceSession] The loaded session
    # @raise [Error] If model file not found
    def load_onnx_session(model_path, providers: nil)
      raise Error, "Model file not found: #{model_path}" unless File.exist?(model_path)

      OnnxRuntime::InferenceSession.new(
        model_path,
        **build_session_options,
        providers: providers || ['CPUExecutionProvider']
      )
    end

    # Load a HuggingFace tokenizer from file
    #
    # @param tokenizer_path [String] Path to tokenizer.json
    # @param max_length [Integer] Maximum sequence length for truncation
    # @return [Tokenizers::Tokenizer] Configured tokenizer
    # @raise [Error] If tokenizer file not found
    def load_tokenizer_from_file(tokenizer_path, max_length:)
      raise Error, "Tokenizer not found: #{tokenizer_path}" unless File.exist?(tokenizer_path)

      tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_path)
      tokenizer.enable_padding(pad_id: 0, pad_token: '[PAD]')
      tokenizer.enable_truncation(max_length)
      tokenizer
    end

    # Load both ONNX model and tokenizer using model_info paths
    #
    # Sets @session and @tokenizer instance variables.
    # Uses @model_dir and @model_info which must be set first.
    #
    # @param model_file_override [String, nil] Override model file path
    # @return [void]
    def setup_model_and_tokenizer(model_file_override: nil)
      model_file = model_file_override || @model_info.model_file
      model_path = File.join(@model_dir, model_file)
      @session = load_onnx_session(model_path, providers: @providers)

      tokenizer_path = File.join(@model_dir, @model_info.tokenizer_file)
      @tokenizer = load_tokenizer_from_file(tokenizer_path, max_length: @model_info.max_length)
    end

    # Get input names from the ONNX session
    #
    # @return [Array<String>] List of input tensor names
    def session_input_names
      @session_input_names ||= @session.inputs.map { |i| i[:name] }
    end

    # Prepare model inputs from tokenizer encodings
    #
    # @param encodings [Array<Tokenizers::Encoding>] Batch of tokenizer encodings
    # @return [Hash] Input tensors for ONNX session
    def prepare_model_inputs(encodings)
      input_ids = encodings.map(&:ids)
      attention_mask = encodings.map(&:attention_mask)

      inputs = { 'input_ids' => input_ids }

      # Add attention_mask/input_mask if the model expects it
      # (SPLADE models use input_mask, most BERT models use attention_mask)
      mask_key = if session_input_names.include?('attention_mask')
                   'attention_mask'
                 elsif session_input_names.include?('input_mask')
                   'input_mask'
                 end

      inputs[mask_key] = attention_mask if mask_key

      # Add token_type_ids/segment_ids if the model expects it
      # (SPLADE models use segment_ids, BERT models use token_type_ids)
      token_type_key = if session_input_names.include?('token_type_ids')
                         'token_type_ids'
                       elsif session_input_names.include?('segment_ids')
                         'segment_ids'
                       end

      if token_type_key
        token_type_ids = encodings.map { |e| e.type_ids || Array.new(e.ids.length, 0) }
        inputs[token_type_key] = token_type_ids
      end

      inputs
    end

    # Tokenize texts and prepare inputs for the model
    #
    # @param texts [Array<String>] Texts to tokenize
    # @return [Hash] Hash with :inputs and :attention_mask keys
    def tokenize_and_prepare(texts)
      encodings = @tokenizer.encode_batch(texts)
      inputs = prepare_model_inputs(encodings)
      { inputs: inputs, attention_mask: encodings.map(&:attention_mask) }
    end

    # Initialize model from local directory instead of downloading
    #
    # @param local_model_dir [String] Path to local model directory
    # @param model_name [String] Name identifier for the model
    # @param threads [Integer, nil] Number of threads for ONNX Runtime
    # @param providers [Array<String>, nil] ONNX execution providers
    # @param quantization [Symbol, nil] Quantization type
    # @param model_file [String, nil] Override model file name
    # @param tokenizer_file [String, nil] Override tokenizer file name
    # @raise [ArgumentError] If local_model_dir doesn't exist
    def initialize_from_local(local_model_dir:, model_name:, threads:, providers:, quantization:, model_file:,
                              tokenizer_file:)
      raise ArgumentError, "Local model directory not found: #{local_model_dir}" unless Dir.exist?(local_model_dir)

      @model_name = model_name
      @threads = threads
      @providers = providers
      @quantization = quantization || Quantization::DEFAULT
      @model_dir = local_model_dir

      validate_quantization!

      # Try to get model info from registry, or create a minimal one
      # Subclasses should override create_local_model_info
      @model_info = resolve_model_info_safe(model_name) || create_local_model_info(
        model_name: model_name,
        model_file: model_file,
        tokenizer_file: tokenizer_file
      )
    end

    # Safely try to resolve model info, returning nil if not found
    #
    # @param model_name [String] Model name to look up
    # @return [BaseModelInfo, nil] Model info or nil
    def resolve_model_info_safe(model_name)
      resolve_model_info(model_name)
    rescue StandardError
      nil
    end

    # Create model info for locally loaded model
    # Subclasses should override this to create appropriate model info type
    #
    # @param model_name [String] Model name identifier
    # @param model_file [String, nil] Model file path
    # @param tokenizer_file [String, nil] Tokenizer file path
    # @return [BaseModelInfo] Model info object
    # @abstract
    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
      raise NotImplementedError, 'Subclasses must implement create_local_model_info for local model loading'
    end
  end
end
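
BaseModel is a mixin: a concrete class includes it, calls `initialize_model` (registry lookup plus download) or `initialize_from_local` (pre-downloaded files), and supplies the registry hooks. The sketch below is hypothetical (the class, the Struct standing in for a model info object, and the file names are illustrative, not part of the gem); it loads an already-exported model from a local directory, so no registry lookup or download is involved:

```ruby
require 'fastembed'

# Hypothetical consumer of Fastembed::BaseModel; loads an exported ONNX model
# plus tokenizer.json from a local directory. All names and values are illustrative.
class LocalEmbedder
  include Fastembed::BaseModel

  # Minimal stand-in for a BaseModelInfo-style object.
  LocalInfo = Struct.new(:model_name, :model_file, :tokenizer_file, :max_length, keyword_init: true)

  def initialize(dir:, name: 'local-model')
    initialize_from_local(local_model_dir: dir, model_name: name,
                          threads: nil, providers: nil, quantization: nil,
                          model_file: 'model.onnx', tokenizer_file: 'tokenizer.json')
    setup_model_and_tokenizer # sets @session and @tokenizer from @model_dir and @model_info
  end

  # Raw ONNX forward pass; pooling and normalization are deliberately omitted here.
  def forward(texts)
    prepared = tokenize_and_prepare(texts)
    @session.run(nil, prepared[:inputs])
  end

  private

  # No registry entry for an ad-hoc local model, so fall through to create_local_model_info.
  def resolve_model_info(_model_name)
    nil
  end

  def create_local_model_info(model_name:, model_file:, tokenizer_file:)
    LocalInfo.new(model_name: model_name, model_file: model_file,
                  tokenizer_file: tokenizer_file, max_length: 512)
  end
end

# embedder = LocalEmbedder.new(dir: '/path/to/exported/model')
# outputs  = embedder.forward(['hello world'])
```

The gem's own TextEmbedding, TextCrossEncoder, and sparse and late-interaction classes follow the same shape, resolving their model info from built-in registries via `resolve_model_info` instead.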

data/lib/fastembed/base_model_info.rb
@@ -0,0 +1,61 @@
# frozen_string_literal: true

module Fastembed
  # Shared functionality for model information classes
  #
  # This module provides common attributes and methods for all model info types
  # (embedding, reranking, sparse, late interaction). It handles metadata like
  # model name, file paths, and HuggingFace source information.
  #
  # @abstract Include in model info classes and call {#initialize_base}
  #
  module BaseModelInfo
    # Default maximum sequence length for tokenization
    DEFAULT_MAX_LENGTH = 512

    # @!attribute [r] model_name
    #   @return [String] Full model identifier (e.g., "BAAI/bge-small-en-v1.5")
    # @!attribute [r] description
    #   @return [String] Human-readable model description
    # @!attribute [r] size_in_gb
    #   @return [Float] Approximate model size in gigabytes
    # @!attribute [r] model_file
    #   @return [String] Relative path to ONNX model file
    # @!attribute [r] tokenizer_file
    #   @return [String] Relative path to tokenizer JSON file
    # @!attribute [r] sources
    #   @return [Hash] Source repositories (e.g., {hf: "Xenova/model-name"})
    # @!attribute [r] max_length
    #   @return [Integer] Maximum token sequence length
    attr_reader :model_name, :description, :size_in_gb, :model_file,
                :tokenizer_file, :sources, :max_length

    # Returns the HuggingFace repository ID
    # @return [String] The HF repo ID for downloading
    def hf_repo
      sources[:hf]
    end

    private

    # Initialize common model info attributes
    #
    # @param model_name [String] Full model identifier
    # @param description [String] Human-readable description
    # @param size_in_gb [Float] Model size in GB
    # @param sources [Hash] Source repositories
    # @param model_file [String] Path to ONNX model file
    # @param tokenizer_file [String] Path to tokenizer file
    # @param max_length [Integer] Maximum sequence length
    def initialize_base(model_name:, description:, size_in_gb:, sources:, model_file:, tokenizer_file:,
                        max_length: DEFAULT_MAX_LENGTH)
      @model_name = model_name
      @description = description
      @size_in_gb = size_in_gb
      @sources = sources
      @model_file = model_file
      @tokenizer_file = tokenizer_file
      @max_length = max_length
    end
  end
end
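
A concrete info class only needs to include the module and call `initialize_base` from its constructor. A hypothetical sketch (the class name and every metadata value are placeholders):

```ruby
require 'fastembed'

# Hypothetical info class built on Fastembed::BaseModelInfo; all values are placeholders.
class ExampleModelInfo
  include Fastembed::BaseModelInfo

  def initialize
    initialize_base(
      model_name: 'example-org/example-embedding-model',
      description: 'Illustrative 384-dimensional text embedding model',
      size_in_gb: 0.13,
      sources: { hf: 'example-org/example-embedding-model' },
      model_file: 'onnx/model.onnx',
      tokenizer_file: 'tokenizer.json'
      # max_length is omitted and falls back to DEFAULT_MAX_LENGTH (512)
    )
  end
end

info = ExampleModelInfo.new
info.hf_repo    # => "example-org/example-embedding-model"
info.max_length # => 512
```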