fastembed 1.0.0 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,193 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Async support for embedding operations
+  #
+  # Provides Future-like objects for running embeddings in background threads.
+  # Useful for parallelizing embedding generation across multiple documents.
+  #
+  # @example Async embedding
+  #   embedding = Fastembed::TextEmbedding.new
+  #   future = embedding.embed_async(documents)
+  #   # ... do other work ...
+  #   vectors = future.value # blocks until complete
+  #
+  # @example Multiple concurrent embeddings
+  #   futures = documents.each_slice(100).map do |batch|
+  #     embedding.embed_async(batch)
+  #   end
+  #   results = futures.flat_map(&:value)
+  #
+  module Async
+    # A Future representing an async embedding operation
+    #
+    # Wraps a background thread that performs embedding, providing
+    # methods to check completion status and retrieve results.
+    #
+    class Future
+      # @return [Thread] The background thread
+      attr_reader :thread
+
+      # Create a new Future
+      #
+      # @yield Block to execute in background thread
+      # @return [Future]
+      def initialize(&block)
+        @result = nil
+        @error = nil
+        @completed = false
+        @mutex = Mutex.new
+        @condition = ConditionVariable.new
+
+        @thread = Thread.new do
+          result = block.call
+          @mutex.synchronize do
+            @result = result
+            @completed = true
+            @condition.broadcast
+          end
+        rescue StandardError => e
+          @mutex.synchronize do
+            @error = e
+            @completed = true
+            @condition.broadcast
+          end
+        end
+      end
+
+      # Check if the operation is complete
+      #
+      # @return [Boolean] True if complete (success or failure)
+      def complete?
+        @mutex.synchronize { @completed }
+      end
+
+      alias completed? complete?
+
+      # Check if the operation is still running
+      #
+      # @return [Boolean] True if still running
+      def pending?
+        !complete?
+      end
+
+      # Check if the operation completed successfully
+      #
+      # @return [Boolean] True if completed without error
+      def success?
+        @mutex.synchronize { @completed && @error.nil? }
+      end
+
+      # Check if the operation failed
+      #
+      # @return [Boolean] True if completed with error
+      def failure?
+        @mutex.synchronize { @completed && !@error.nil? }
+      end
+
+      # Get the result, blocking until complete
+      #
+      # @param timeout [Numeric, nil] Maximum seconds to wait (nil = forever)
+      # @return [Object] The result of the async operation
+      # @raise [StandardError] If the operation raised an error
+      # @raise [Timeout::Error] If timeout expires before completion
+      def value(timeout: nil)
+        wait(timeout: timeout)
+        raise @error if @error
+
+        @result
+      end
+
+      alias result value
+
+      # Wait for completion without retrieving the result
+      #
+      # @param timeout [Numeric, nil] Maximum seconds to wait (nil = forever)
+      # @return [Boolean] True if completed, false if timed out
+      def wait(timeout: nil)
+        @mutex.synchronize do
+          return true if @completed
+
+          if timeout
+            deadline = Time.now + timeout
+            until @completed
+              remaining = deadline - Time.now
+              break if remaining <= 0
+
+              @condition.wait(@mutex, remaining)
+            end
+          else
+            @condition.wait(@mutex) until @completed
+          end
+
+          @completed
+        end
+      end
+
+      # Get the error if the operation failed
+      #
+      # @return [StandardError, nil] The error, or nil if successful/pending
+      def error
+        @mutex.synchronize { @error }
+      end
+
+      # Apply a transformation to the result
+      #
+      # @yield [result] Block to transform the result
+      # @return [Future] A new Future with the transformed result
+      def then(&block)
+        Future.new do
+          block.call(value)
+        end
+      end
+
+      # Handle errors
+      #
+      # @yield [error] Block to handle errors
+      # @return [Future] A new Future that handles errors
+      def rescue(&block)
+        Future.new do
+          value
+        rescue StandardError => e
+          block.call(e)
+        end
+      end
+    end
+
+    # Run multiple futures concurrently and wait for all to complete
+    #
+    # @param futures [Array<Future>] Futures to wait for
+    # @param timeout [Numeric, nil] Maximum seconds to wait
+    # @return [Array] Results from all futures
+    # @raise [StandardError] If any future raised an error
+    def self.all(futures, timeout: nil)
+      futures.each { |f| f.wait(timeout: timeout) }
+      futures.map(&:value)
+    end
+
+    # Run multiple futures concurrently and return first completed
+    #
+    # @param futures [Array<Future>] Futures to race
+    # @param timeout [Numeric, nil] Maximum seconds to wait
+    # @return [Object] Result from first completed future
+    # @raise [Timeout::Error] If timeout expires before any future completes
+    def self.race(futures, timeout: nil)
+      raise ArgumentError, 'No futures provided' if futures.empty?
+
+      deadline = timeout ? Time.now + timeout : nil
+      sleep_time = 0.001 # Start with 1ms
+
+      loop do
+        futures.each do |future|
+          return future.value if future.complete?
+        end
+
+        raise Timeout::Error, 'No future completed within timeout' if deadline && Time.now >= deadline
+
+        sleep sleep_time
+        # Exponential backoff up to 10ms to reduce CPU usage for long waits
+        sleep_time = [sleep_time * 1.5, 0.01].min
+      end
+    end
+  end
+end
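
The hunk above adds Fastembed::Async: a thread-backed Future plus the Async.all and Async.race helpers. A minimal usage sketch, assuming TextEmbedding#embed_async returns an Async::Future and each result is an array of vectors, as the doc comments above describe:

    require 'fastembed'

    embedding = Fastembed::TextEmbedding.new
    documents = ['first text', 'second text', 'third text', 'fourth text']

    # One future per batch; each embedding call runs in a background thread.
    futures = documents.each_slice(2).map { |batch| embedding.embed_async(batch) }

    # Async.all waits for every future (optionally with a timeout) and returns
    # each future's result; flatten one level to get a single list of vectors.
    vectors = Fastembed::Async.all(futures).flatten(1)

    # Async.race polls the futures and returns the first completed result,
    # raising Timeout::Error if nothing finishes within the timeout.
    first_batch = Fastembed::Async.race(futures, timeout: 30)
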
@@ -0,0 +1,247 @@
+# frozen_string_literal: true
+
+require 'onnxruntime'
+require 'tokenizers'
+
+module Fastembed
+  # Shared functionality for model classes
+  #
+  # This module provides common initialization and utility methods used by
+  # all model types (TextEmbedding, TextCrossEncoder, TextSparseEmbedding, etc.).
+  # It handles model downloading, ONNX session creation, and tokenizer loading.
+  #
+  # @abstract Include in model classes and call {#initialize_model}
+  #
+  module BaseModel
+    # @!attribute [r] model_name
+    #   @return [String] Name of the loaded model
+    # @!attribute [r] model_info
+    #   @return [BaseModelInfo] Model metadata and configuration
+    # @!attribute [r] quantization
+    #   @return [Symbol] Current quantization type
+    attr_reader :model_name, :model_info, :quantization
+
+    private
+
+    # Common initialization logic for all model types
+    # @param model_name [String] Name of the model
+    # @param cache_dir [String, nil] Custom cache directory
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param show_progress [Boolean] Whether to show download progress
+    # @param quantization [Symbol] Quantization type (:fp32, :fp16, :int8, :uint8, :q4)
+    def initialize_model(model_name:, cache_dir:, threads:, providers:, show_progress:, quantization: nil)
+      @model_name = model_name
+      @threads = threads
+      @providers = providers
+      @quantization = quantization || Quantization::DEFAULT
+
+      validate_quantization!
+
+      ModelManagement.cache_dir = cache_dir if cache_dir
+
+      @model_info = resolve_model_info(model_name)
+      @model_dir = retrieve_model(model_name, show_progress: show_progress)
+    end
+
+    # Validate that the quantization type is supported
+    # @raise [ArgumentError] If quantization type is invalid
+    def validate_quantization!
+      return if Quantization.valid?(@quantization)
+
+      valid_types = Quantization::TYPES.keys.join(', ')
+      raise ArgumentError, "Invalid quantization type: #{@quantization}. Valid types: #{valid_types}"
+    end
+
+    # Get the model file path, accounting for quantization
+    # @return [String] Path to quantized model file (or base if fp32)
+    def quantized_model_file
+      Quantization.model_file(@model_info.model_file, @quantization)
+    end
+
+    # Override in subclasses to resolve from appropriate registry
+    #
+    # @param _model_name [String] Name of the model
+    # @return [BaseModelInfo] Model information object
+    # @raise [NotImplementedError] If not overridden in subclass
+    # @abstract
+    def resolve_model_info(_model_name)
+      raise NotImplementedError, 'Subclasses must implement resolve_model_info'
+    end
+
+    # Download or retrieve cached model
+    #
+    # @param model_name [String] Name of the model
+    # @param show_progress [Boolean] Whether to show download progress
+    # @return [String] Path to model directory
+    def retrieve_model(model_name, show_progress:)
+      ModelManagement.retrieve_model(
+        model_name,
+        model_info: @model_info,
+        show_progress: show_progress
+      )
+    end
+
+    # Build ONNX session options hash
+    # @return [Hash] Options for OnnxRuntime::InferenceSession
+    def build_session_options
+      options = {}
+      options[:inter_op_num_threads] = @threads if @threads
+      options[:intra_op_num_threads] = @threads if @threads
+      options
+    end
+
+    # Load an ONNX inference session
+    #
+    # @param model_path [String] Path to ONNX model file
+    # @param providers [Array<String>, nil] Execution providers
+    # @return [OnnxRuntime::InferenceSession] The loaded session
+    # @raise [Error] If model file not found
+    def load_onnx_session(model_path, providers: nil)
+      raise Error, "Model file not found: #{model_path}" unless File.exist?(model_path)
+
+      OnnxRuntime::InferenceSession.new(
+        model_path,
+        **build_session_options,
+        providers: providers || ['CPUExecutionProvider']
+      )
+    end
+
+    # Load a HuggingFace tokenizer from file
+    #
+    # @param tokenizer_path [String] Path to tokenizer.json
+    # @param max_length [Integer] Maximum sequence length for truncation
+    # @return [Tokenizers::Tokenizer] Configured tokenizer
+    # @raise [Error] If tokenizer file not found
+    def load_tokenizer_from_file(tokenizer_path, max_length:)
+      raise Error, "Tokenizer not found: #{tokenizer_path}" unless File.exist?(tokenizer_path)
+
+      tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_path)
+      tokenizer.enable_padding(pad_id: 0, pad_token: '[PAD]')
+      tokenizer.enable_truncation(max_length)
+      tokenizer
+    end
+
+    # Load both ONNX model and tokenizer using model_info paths
+    #
+    # Sets @session and @tokenizer instance variables.
+    # Uses @model_dir and @model_info which must be set first.
+    #
+    # @param model_file_override [String, nil] Override model file path
+    # @return [void]
+    def setup_model_and_tokenizer(model_file_override: nil)
+      model_file = model_file_override || @model_info.model_file
+      model_path = File.join(@model_dir, model_file)
+      @session = load_onnx_session(model_path, providers: @providers)
+
+      tokenizer_path = File.join(@model_dir, @model_info.tokenizer_file)
+      @tokenizer = load_tokenizer_from_file(tokenizer_path, max_length: @model_info.max_length)
+    end
+
+    # Get input names from the ONNX session
+    #
+    # @return [Array<String>] List of input tensor names
+    def session_input_names
+      @session_input_names ||= @session.inputs.map { |i| i[:name] }
+    end
+
+    # Prepare model inputs from tokenizer encodings
+    #
+    # @param encodings [Array<Tokenizers::Encoding>] Batch of tokenizer encodings
+    # @return [Hash] Input tensors for ONNX session
+    def prepare_model_inputs(encodings)
+      input_ids = encodings.map(&:ids)
+      attention_mask = encodings.map(&:attention_mask)
+
+      inputs = { 'input_ids' => input_ids }
+
+      # Add attention_mask/input_mask if the model expects it
+      # (SPLADE models use input_mask, most BERT models use attention_mask)
+      mask_key = if session_input_names.include?('attention_mask')
+                   'attention_mask'
+                 elsif session_input_names.include?('input_mask')
+                   'input_mask'
+                 end
+
+      inputs[mask_key] = attention_mask if mask_key
+
+      # Add token_type_ids/segment_ids if the model expects it
+      # (SPLADE models use segment_ids, BERT models use token_type_ids)
+      token_type_key = if session_input_names.include?('token_type_ids')
+                         'token_type_ids'
+                       elsif session_input_names.include?('segment_ids')
+                         'segment_ids'
+                       end
+
+      if token_type_key
+        token_type_ids = encodings.map { |e| e.type_ids || Array.new(e.ids.length, 0) }
+        inputs[token_type_key] = token_type_ids
+      end
+
+      inputs
+    end
+
+    # Tokenize texts and prepare inputs for the model
+    #
+    # @param texts [Array<String>] Texts to tokenize
+    # @return [Hash] Hash with :inputs and :attention_mask keys
+    def tokenize_and_prepare(texts)
+      encodings = @tokenizer.encode_batch(texts)
+      inputs = prepare_model_inputs(encodings)
+      { inputs: inputs, attention_mask: encodings.map(&:attention_mask) }
+    end
+
+    # Initialize model from local directory instead of downloading
+    #
+    # @param local_model_dir [String] Path to local model directory
+    # @param model_name [String] Name identifier for the model
+    # @param threads [Integer, nil] Number of threads for ONNX Runtime
+    # @param providers [Array<String>, nil] ONNX execution providers
+    # @param quantization [Symbol, nil] Quantization type
+    # @param model_file [String, nil] Override model file name
+    # @param tokenizer_file [String, nil] Override tokenizer file name
+    # @raise [ArgumentError] If local_model_dir doesn't exist
+    def initialize_from_local(local_model_dir:, model_name:, threads:, providers:, quantization:, model_file:,
+                              tokenizer_file:)
+      raise ArgumentError, "Local model directory not found: #{local_model_dir}" unless Dir.exist?(local_model_dir)
+
+      @model_name = model_name
+      @threads = threads
+      @providers = providers
+      @quantization = quantization || Quantization::DEFAULT
+      @model_dir = local_model_dir
+
+      validate_quantization!
+
+      # Try to get model info from registry, or create a minimal one
+      # Subclasses should override create_local_model_info
+      @model_info = resolve_model_info_safe(model_name) || create_local_model_info(
+        model_name: model_name,
+        model_file: model_file,
+        tokenizer_file: tokenizer_file
+      )
+    end
+
+    # Safely try to resolve model info, returning nil if not found
+    #
+    # @param model_name [String] Model name to look up
+    # @return [BaseModelInfo, nil] Model info or nil
+    def resolve_model_info_safe(model_name)
+      resolve_model_info(model_name)
+    rescue StandardError
+      nil
+    end
+
+    # Create model info for locally loaded model
+    # Subclasses should override this to create appropriate model info type
+    #
+    # @param model_name [String] Model name identifier
+    # @param model_file [String, nil] Model file path
+    # @param tokenizer_file [String, nil] Tokenizer file path
+    # @return [BaseModelInfo] Model info object
+    # @abstract
+    def create_local_model_info(model_name:, model_file:, tokenizer_file:)
+      raise NotImplementedError, 'Subclasses must implement create_local_model_info for local model loading'
+    end
+  end
+end
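
BaseModel above is private plumbing shared by the concrete model classes, which are not part of this diff. The sketch below shows how the new local-loading path (initialize_from_local plus setup_model_and_tokenizer) could be exercised by a hypothetical including class; the class name, the LocalInfo struct, the model.onnx/tokenizer.json layout, and the direct @session.run call are illustrative assumptions, not gem API:

    # Hypothetical subclass; Fastembed's real model classes are not shown here.
    class LocalOnlyEmbedding
      include Fastembed::BaseModel

      # Stand-in for a real model info object; only the attributes BaseModel
      # reads (model_file, tokenizer_file, max_length) are provided.
      LocalInfo = Struct.new(:model_name, :model_file, :tokenizer_file, :max_length, keyword_init: true)

      def initialize(dir)
        initialize_from_local(local_model_dir: dir, model_name: 'local-model',
                              threads: nil, providers: nil, quantization: nil,
                              model_file: 'model.onnx', tokenizer_file: 'tokenizer.json')
        setup_model_and_tokenizer
      end

      # Returns raw ONNX outputs; pooling/normalization is left to the caller.
      def embed_raw(texts)
        prepared = tokenize_and_prepare(texts)
        # OnnxRuntime::InferenceSession#run takes output names (nil = all) and an input hash.
        @session.run(nil, prepared[:inputs])
      end

      private

      # No registry for purely local models; returning nil makes
      # initialize_from_local fall back to create_local_model_info.
      def resolve_model_info(_model_name)
        nil
      end

      def create_local_model_info(model_name:, model_file:, tokenizer_file:)
        LocalInfo.new(model_name: model_name,
                      model_file: model_file || 'model.onnx',
                      tokenizer_file: tokenizer_file || 'tokenizer.json',
                      max_length: 512)
      end
    end
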
@@ -0,0 +1,61 @@
+# frozen_string_literal: true
+
+module Fastembed
+  # Shared functionality for model information classes
+  #
+  # This module provides common attributes and methods for all model info types
+  # (embedding, reranking, sparse, late interaction). It handles metadata like
+  # model name, file paths, and HuggingFace source information.
+  #
+  # @abstract Include in model info classes and call {#initialize_base}
+  #
+  module BaseModelInfo
+    # Default maximum sequence length for tokenization
+    DEFAULT_MAX_LENGTH = 512
+
+    # @!attribute [r] model_name
+    #   @return [String] Full model identifier (e.g., "BAAI/bge-small-en-v1.5")
+    # @!attribute [r] description
+    #   @return [String] Human-readable model description
+    # @!attribute [r] size_in_gb
+    #   @return [Float] Approximate model size in gigabytes
+    # @!attribute [r] model_file
+    #   @return [String] Relative path to ONNX model file
+    # @!attribute [r] tokenizer_file
+    #   @return [String] Relative path to tokenizer JSON file
+    # @!attribute [r] sources
+    #   @return [Hash] Source repositories (e.g., {hf: "Xenova/model-name"})
+    # @!attribute [r] max_length
+    #   @return [Integer] Maximum token sequence length
+    attr_reader :model_name, :description, :size_in_gb, :model_file,
+                :tokenizer_file, :sources, :max_length
+
+    # Returns the HuggingFace repository ID
+    # @return [String] The HF repo ID for downloading
+    def hf_repo
+      sources[:hf]
+    end
+
+    private
+
+    # Initialize common model info attributes
+    #
+    # @param model_name [String] Full model identifier
+    # @param description [String] Human-readable description
+    # @param size_in_gb [Float] Model size in GB
+    # @param sources [Hash] Source repositories
+    # @param model_file [String] Path to ONNX model file
+    # @param tokenizer_file [String] Path to tokenizer file
+    # @param max_length [Integer] Maximum sequence length
+    def initialize_base(model_name:, description:, size_in_gb:, sources:, model_file:, tokenizer_file:,
+                        max_length: DEFAULT_MAX_LENGTH)
+      @model_name = model_name
+      @description = description
+      @size_in_gb = size_in_gb
+      @sources = sources
+      @model_file = model_file
+      @tokenizer_file = tokenizer_file
+      @max_length = max_length
+    end
+  end
+end
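
A concrete info class that mixes in BaseModelInfo only needs to forward its keyword arguments to initialize_base. The class name and sample values below are illustrative, not entries from the gem's model registry:

    class ExampleModelInfo
      include Fastembed::BaseModelInfo

      def initialize(**kwargs)
        initialize_base(**kwargs)
      end
    end

    info = ExampleModelInfo.new(
      model_name: 'BAAI/bge-small-en-v1.5',
      description: 'Small English embedding model',
      size_in_gb: 0.067,
      sources: { hf: 'Xenova/bge-small-en-v1.5' },
      model_file: 'onnx/model.onnx',
      tokenizer_file: 'tokenizer.json'
    )

    info.hf_repo    # => "Xenova/bge-small-en-v1.5"
    info.max_length # => 512, from DEFAULT_MAX_LENGTH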