fastembed 1.0.0 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +1 -0
- data/.yardopts +6 -0
- data/BENCHMARKS.md +124 -1
- data/CHANGELOG.md +14 -0
- data/README.md +395 -74
- data/benchmark/compare_all.rb +167 -0
- data/benchmark/compare_python.py +60 -0
- data/benchmark/memory_profile.rb +70 -0
- data/benchmark/profile.rb +198 -0
- data/benchmark/reranker_benchmark.rb +158 -0
- data/exe/fastembed +6 -0
- data/fastembed.gemspec +3 -0
- data/lib/fastembed/async.rb +193 -0
- data/lib/fastembed/base_model.rb +247 -0
- data/lib/fastembed/base_model_info.rb +61 -0
- data/lib/fastembed/cli.rb +745 -0
- data/lib/fastembed/custom_model_registry.rb +255 -0
- data/lib/fastembed/image_embedding.rb +313 -0
- data/lib/fastembed/late_interaction_embedding.rb +260 -0
- data/lib/fastembed/late_interaction_model_info.rb +91 -0
- data/lib/fastembed/model_info.rb +59 -19
- data/lib/fastembed/model_management.rb +82 -23
- data/lib/fastembed/onnx_embedding_model.rb +25 -4
- data/lib/fastembed/pooling.rb +39 -3
- data/lib/fastembed/progress.rb +52 -0
- data/lib/fastembed/quantization.rb +75 -0
- data/lib/fastembed/reranker_model_info.rb +91 -0
- data/lib/fastembed/sparse_embedding.rb +261 -0
- data/lib/fastembed/sparse_model_info.rb +80 -0
- data/lib/fastembed/text_cross_encoder.rb +217 -0
- data/lib/fastembed/text_embedding.rb +161 -28
- data/lib/fastembed/validators.rb +59 -0
- data/lib/fastembed/version.rb +1 -1
- data/lib/fastembed.rb +42 -1
- data/plan.md +257 -0
- data/scripts/verify_models.rb +229 -0
- metadata +70 -3
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fastembed
|
|
4
|
+
# Registry for custom user-defined models
|
|
5
|
+
#
|
|
6
|
+
# Allows users to register arbitrary ONNX models that aren't in the built-in registry.
|
|
7
|
+
# Custom models can be loaded from HuggingFace or local directories.
|
|
8
|
+
#
|
|
9
|
+
# @example Register a custom embedding model
|
|
10
|
+
# Fastembed.register_model(
|
|
11
|
+
# model_name: 'my-org/my-model',
|
|
12
|
+
# dim: 768,
|
|
13
|
+
# description: 'My custom model',
|
|
14
|
+
# sources: { hf: 'my-org/my-model-onnx' }
|
|
15
|
+
# )
|
|
16
|
+
# embed = Fastembed::TextEmbedding.new(model_name: 'my-org/my-model')
|
|
17
|
+
#
|
|
18
|
+
# @example Register a local model
|
|
19
|
+
# Fastembed.register_model(
|
|
20
|
+
# model_name: 'local-model',
|
|
21
|
+
# dim: 384,
|
|
22
|
+
# description: 'Local model',
|
|
23
|
+
# sources: {}
|
|
24
|
+
# )
|
|
25
|
+
# embed = Fastembed::TextEmbedding.new(
|
|
26
|
+
# model_name: 'local-model',
|
|
27
|
+
# local_model_dir: '/path/to/model'
|
|
28
|
+
# )
|
|
29
|
+
#
|
|
30
|
+
module CustomModelRegistry
|
|
31
|
+
class << self
|
|
32
|
+
# Registry of user-registered dense embedding models, created on first access.
# @return [Hash<String, ModelInfo>] registered models keyed by model name
def embedding_models
  return @embedding_models if @embedding_models

  @embedding_models = {}
end
|
|
37
|
+
|
|
38
|
+
# Registry of user-registered reranker models, created on first access.
# @return [Hash<String, RerankerModelInfo>] registered models keyed by model name
def reranker_models
  return @reranker_models if @reranker_models

  @reranker_models = {}
end
|
|
43
|
+
|
|
44
|
+
# Registry of user-registered sparse embedding models, created on first access.
# @return [Hash<String, SparseModelInfo>] registered models keyed by model name
def sparse_models
  return @sparse_models if @sparse_models

  @sparse_models = {}
end
|
|
49
|
+
|
|
50
|
+
# Registry of user-registered late interaction models, created on first access.
# @return [Hash<String, LateInteractionModelInfo>] registered models keyed by model name
def late_interaction_models
  return @late_interaction_models if @late_interaction_models

  @late_interaction_models = {}
end
|
|
55
|
+
|
|
56
|
+
# Build a ModelInfo from the given attributes and store it in the custom
# embedding registry under +model_name+. An existing entry with the same
# name is replaced.
#
# @param model_name [String] Unique model identifier (registry key)
# @param dim [Integer] Output embedding dimension
# @param description [String] Human-readable description
# @param sources [Hash] Source repositories (e.g., { hf: 'org/repo' })
# @param size_in_gb [Float] Approximate model size
# @param model_file [String] Path to ONNX file within model directory
# @param tokenizer_file [String] Path to tokenizer.json
# @param pooling [Symbol] Pooling strategy (:mean or :cls)
# @param normalize [Boolean] Whether to L2 normalize outputs
# @param max_length [Integer] Maximum sequence length
# @return [ModelInfo] The registered model info
def register_embedding_model(model_name:, dim:,
                             description: 'Custom model', sources: {}, size_in_gb: 0,
                             model_file: 'model.onnx', tokenizer_file: 'tokenizer.json',
                             pooling: :mean, normalize: true, max_length: 512)
  registry = embedding_models
  attributes = {
    model_name: model_name,
    dim: dim,
    description: description,
    sources: sources,
    size_in_gb: size_in_gb,
    model_file: model_file,
    tokenizer_file: tokenizer_file,
    pooling: pooling,
    normalize: normalize,
    max_length: max_length
  }
  registry[model_name] = ModelInfo.new(**attributes)
end
|
|
94
|
+
|
|
95
|
+
# Build a RerankerModelInfo from the given attributes and store it in the
# custom reranker registry under +model_name+. An existing entry with the
# same name is replaced.
#
# @param model_name [String] Unique model identifier (registry key)
# @param description [String] Human-readable description
# @param sources [Hash] Source repositories
# @param size_in_gb [Float] Approximate model size
# @param model_file [String] Path to ONNX file
# @param tokenizer_file [String] Path to tokenizer.json
# @return [RerankerModelInfo] The registered model info
def register_reranker_model(model_name:,
                            description: 'Custom reranker', sources: {}, size_in_gb: 0,
                            model_file: 'onnx/model.onnx', tokenizer_file: 'tokenizer.json')
  registry = reranker_models
  attributes = {
    model_name: model_name,
    description: description,
    sources: sources,
    size_in_gb: size_in_gb,
    model_file: model_file,
    tokenizer_file: tokenizer_file
  }
  registry[model_name] = RerankerModelInfo.new(**attributes)
end
|
|
121
|
+
|
|
122
|
+
# Build a SparseModelInfo from the given attributes and store it in the
# custom sparse registry under +model_name+. An existing entry with the
# same name is replaced.
#
# @param model_name [String] Unique model identifier (registry key)
# @param description [String] Human-readable description
# @param sources [Hash] Source repositories
# @param size_in_gb [Float] Approximate model size
# @param model_file [String] Path to ONNX file
# @param tokenizer_file [String] Path to tokenizer.json
# @param max_length [Integer] Maximum sequence length
# @return [SparseModelInfo] The registered model info
def register_sparse_model(model_name:,
                          description: 'Custom sparse model', sources: {}, size_in_gb: 0,
                          model_file: 'onnx/model.onnx', tokenizer_file: 'tokenizer.json',
                          max_length: 512)
  registry = sparse_models
  attributes = {
    model_name: model_name,
    description: description,
    sources: sources,
    size_in_gb: size_in_gb,
    model_file: model_file,
    tokenizer_file: tokenizer_file,
    max_length: max_length
  }
  registry[model_name] = SparseModelInfo.new(**attributes)
end
|
|
151
|
+
|
|
152
|
+
# Build a LateInteractionModelInfo from the given attributes and store it in
# the custom late interaction registry under +model_name+. An existing entry
# with the same name is replaced.
#
# @param model_name [String] Unique model identifier (registry key)
# @param dim [Integer] Output embedding dimension per token
# @param description [String] Human-readable description
# @param sources [Hash] Source repositories
# @param size_in_gb [Float] Approximate model size
# @param model_file [String] Path to ONNX file
# @param tokenizer_file [String] Path to tokenizer.json
# @param max_length [Integer] Maximum sequence length
# @return [LateInteractionModelInfo] The registered model info
def register_late_interaction_model(model_name:, dim:,
                                    description: 'Custom late interaction model', sources: {},
                                    size_in_gb: 0, model_file: 'onnx/model.onnx',
                                    tokenizer_file: 'tokenizer.json', max_length: 512)
  registry = late_interaction_models
  attributes = {
    model_name: model_name,
    dim: dim,
    description: description,
    sources: sources,
    size_in_gb: size_in_gb,
    model_file: model_file,
    tokenizer_file: tokenizer_file,
    max_length: max_length
  }
  registry[model_name] = LateInteractionModelInfo.new(**attributes)
end
|
|
184
|
+
|
|
185
|
+
# Remove a custom model from the registry that matches +type+.
#
# @param model_name [String] Model to unregister
# @param type [Symbol, String] Model type (:embedding, :reranker, :sparse,
#   :late_interaction); a String naming one of these is accepted as well
# @return [Boolean] True if a model was removed, false if none was registered
# @raise [ArgumentError] if +type+ does not name a known model type
def unregister_model(model_name, type: :embedding)
  # Accept 'reranker' as well as :reranker; anything that cannot be turned
  # into a symbol falls through to the ArgumentError below.
  kind = type.respond_to?(:to_sym) ? type.to_sym : type

  registry =
    case kind
    when :embedding then embedding_models
    when :reranker then reranker_models
    when :sparse then sparse_models
    when :late_interaction then late_interaction_models
    else raise ArgumentError, "Unknown model type: #{type}"
    end

  # Hash#delete returns nil when the key was absent.
  !registry.delete(model_name).nil?
end
|
|
200
|
+
|
|
201
|
+
# Reset every custom registry to a fresh empty Hash. Previously returned
# registry hashes are left untouched; they are simply no longer referenced.
# @return [void]
def clear_all
  %i[@embedding_models @reranker_models @sparse_models].each do |registry_ivar|
    instance_variable_set(registry_ivar, {})
  end
  @late_interaction_models = {}
end
|
|
209
|
+
|
|
210
|
+
# Summarize the names of every registered custom model, grouped by type.
# @return [Hash{Symbol => Array}] registry keys under :embedding, :reranker,
#   :sparse and :late_interaction
def list_all
  registries = {
    embedding: embedding_models,
    reranker: reranker_models,
    sparse: sparse_models,
    late_interaction: late_interaction_models
  }
  registries.transform_values(&:keys)
end
|
|
220
|
+
end
|
|
221
|
+
end
|
|
222
|
+
|
|
223
|
+
# Module-level shortcuts that forward to CustomModelRegistry
class << self
  # Register a custom embedding model; all keyword arguments are passed through.
  # @see CustomModelRegistry.register_embedding_model
  def register_model(**options)
    CustomModelRegistry.register_embedding_model(**options)
  end

  # Register a custom reranker model; all keyword arguments are passed through.
  # @see CustomModelRegistry.register_reranker_model
  def register_reranker(**options)
    CustomModelRegistry.register_reranker_model(**options)
  end

  # Register a custom sparse model; all keyword arguments are passed through.
  # @see CustomModelRegistry.register_sparse_model
  def register_sparse_model(**options)
    CustomModelRegistry.register_sparse_model(**options)
  end

  # Register a custom late interaction model; all keyword arguments are passed through.
  # @see CustomModelRegistry.register_late_interaction_model
  def register_late_interaction_model(**options)
    CustomModelRegistry.register_late_interaction_model(**options)
  end

  # Names of all custom registered models
  # @return [Hash] Custom model names grouped by type
  def custom_models
    CustomModelRegistry.list_all
  end
end
|
|
255
|
+
end
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fastembed
|
|
4
|
+
# Metadata for an image embedding model: the shared base fields plus the
# output dimension and the image preprocessing parameters.
class ImageModelInfo
  include BaseModelInfo

  attr_reader :dim, :image_size, :mean, :std

  # @param model_name [String] Model identifier
  # @param dim [Integer] Output embedding dimension
  # @param description [String] Human-readable description
  # @param size_in_gb [Numeric] Approximate model size
  # @param sources [Hash] Source repositories
  # @param model_file [String] Path to the ONNX file within the model directory
  # @param image_size [Integer] Side length images are resized to
  # @param mean [Array<Float>] Per-channel mean used for pixel normalization
  # @param std [Array<Float>] Per-channel standard deviation used for pixel normalization
  def initialize(model_name:, dim:, description:, size_in_gb:, sources:,
                 model_file: 'model.onnx', image_size: 224,
                 mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225])
    # The tokenizer file and max length are passed as empty placeholders.
    base_fields = {
      model_name: model_name,
      description: description,
      size_in_gb: size_in_gb,
      sources: sources,
      model_file: model_file,
      tokenizer_file: '',
      max_length: 0
    }
    initialize_base(**base_fields)

    @dim = dim
    @image_size = image_size
    @mean = mean
    @std = std
  end

  # @return [Hash] All image-model attributes keyed by symbol
  def to_h
    fields = %i[model_name dim description size_in_gb sources model_file image_size mean std]
    fields.to_h { |field| [field, send(field)] }
  end
end
|
|
50
|
+
|
|
51
|
+
# Builds one [name, ImageModelInfo] registry entry. Every built-in model is
# fetched from the HuggingFace repo of the same name and uses 224x224 input.
# NOTE(review): all entries rely on ImageModelInfo's default mean/std; confirm
# these match each model's published preprocessing configuration.
build_image_model = lambda do |name, dim:, description:, size_in_gb:, model_file:|
  info = ImageModelInfo.new(
    model_name: name,
    dim: dim,
    description: description,
    size_in_gb: size_in_gb,
    sources: { hf: name },
    model_file: model_file,
    image_size: 224
  )
  [name, info]
end

# Registry of supported image embedding models, keyed by model name
SUPPORTED_IMAGE_MODELS = [
  build_image_model.call('Qdrant/clip-ViT-B-32-vision',
                         dim: 512,
                         description: 'CLIP ViT-B/32 vision encoder',
                         size_in_gb: 0.34,
                         model_file: 'model.onnx'),
  build_image_model.call('Qdrant/resnet50-onnx',
                         dim: 2048,
                         description: 'ResNet-50 image encoder',
                         size_in_gb: 0.10,
                         model_file: 'model.onnx'),
  build_image_model.call('jinaai/jina-clip-v1',
                         dim: 768,
                         description: 'Jina CLIP v1 vision encoder',
                         size_in_gb: 0.35,
                         model_file: 'onnx/vision_model.onnx')
].to_h.freeze

# Model used when ImageEmbedding.new is called without a model_name
DEFAULT_IMAGE_MODEL = 'Qdrant/clip-ViT-B-32-vision'
|
|
83
|
+
|
|
84
|
+
# Image embedding model for converting images to vectors
|
|
85
|
+
#
|
|
86
|
+
# Supports CLIP and ResNet models for image search and multimodal applications.
|
|
87
|
+
# Requires the mini_magick gem for image processing.
|
|
88
|
+
#
|
|
89
|
+
# @example Basic usage
|
|
90
|
+
# image_embed = Fastembed::ImageEmbedding.new
|
|
91
|
+
# vectors = image_embed.embed(["path/to/image.jpg"]).to_a
|
|
92
|
+
#
|
|
93
|
+
# @example With URLs
|
|
94
|
+
# vectors = image_embed.embed(["https://example.com/image.jpg"]).to_a
|
|
95
|
+
#
|
|
96
|
+
class ImageEmbedding
|
|
97
|
+
attr_reader :model_name, :model_info, :dim
|
|
98
|
+
|
|
99
|
+
# Set up an image embedding model, either from a local directory or by
# retrieving a supported model through ModelManagement, then create the
# ONNX inference session.
#
# @param model_name [String] Name of the model to use
# @param cache_dir [String, nil] Custom cache directory for models
# @param threads [Integer, nil] Number of threads for ONNX Runtime
# @param providers [Array<String>, nil] ONNX execution providers
# @param show_progress [Boolean] Whether to show download progress
# @param local_model_dir [String, nil] Load model from local directory instead of downloading
# @param model_file [String, nil] Override model file name (e.g., "model.onnx")
def initialize(model_name: DEFAULT_IMAGE_MODEL, cache_dir: nil, threads: nil, providers: nil,
               show_progress: true, local_model_dir: nil, model_file: nil)
  require_mini_magick!

  providers ||= ['CPUExecutionProvider']
  @model_name = model_name
  @threads = threads
  @providers = providers
  @model_file_override = model_file

  unless local_model_dir
    # NOTE: this assigns the cache directory on ModelManagement itself,
    # not just for this instance.
    ModelManagement.cache_dir = cache_dir if cache_dir
    @model_info = resolve_model_info(model_name)
    @model_dir = retrieve_model(model_name, show_progress: show_progress)
  end
  if local_model_dir
    initialize_from_local(local_model_dir: local_model_dir, model_name: model_name, model_file: model_file)
  end

  @dim = @model_info.dim
  setup_model
end
|
|
135
|
+
|
|
136
|
+
# Generate embeddings for images
#
# Work is deferred: batches are only computed as the returned enumerator
# is consumed.
#
# @param images [Array<String>, String] Image path(s) or URL(s) to embed
# @param batch_size [Integer] Number of images to process at once (must be positive)
# @yield [Progress] Optional progress callback called after each batch
# @return [Enumerator] Lazy enumerator yielding embedding vectors
# @raise [ArgumentError] if batch_size is not a positive number
def embed(images, batch_size: 32, &progress_callback)
  # Fail fast with a clear message instead of a FloatDomainError from the
  # batch-count division or a deferred error inside the enumerator.
  unless batch_size.is_a?(Numeric) && batch_size.positive?
    raise ArgumentError, "batch_size must be a positive integer, got #{batch_size.inspect}"
  end

  images = [images] if images.is_a?(String)
  return Enumerator.new { |_| } if images.empty?

  total_batches = (images.length.to_f / batch_size).ceil

  Enumerator.new do |yielder|
    images.each_slice(batch_size).with_index(1) do |batch, batch_num|
      compute_embeddings(batch).each { |emb| yielder << emb }
      next unless progress_callback

      progress_callback.call(
        Progress.new(current: batch_num, total: total_batches, batch_size: batch_size)
      )
    end
  end
end
|
|
160
|
+
|
|
161
|
+
# Asynchronous variant of #embed: wraps the full embedding run in a future.
#
# @param images [Array<String>, String] Image path(s) or URL(s) to embed
# @param batch_size [Integer] Number of images to process at once
# @return [Async::Future] Future that resolves to array of embedding vectors
def embed_async(images, batch_size: 32)
  Async::Future.new do
    vectors = embed(images, batch_size: batch_size)
    vectors.to_a
  end
end
|
|
169
|
+
|
|
170
|
+
# Describe every built-in image model.
#
# @return [Array<Hash>] One attribute hash per entry in SUPPORTED_IMAGE_MODELS
def self.list_supported_models
  SUPPORTED_IMAGE_MODELS.each_value.map { |info| info.to_h }
end
|
|
176
|
+
|
|
177
|
+
private
|
|
178
|
+
|
|
179
|
+
# Load the optional mini_magick dependency on demand, translating a missing
# gem into a library Error that tells the user how to install it.
def require_mini_magick!
  begin
    require 'mini_magick'
  rescue LoadError
    raise Error, 'Image embedding requires the mini_magick gem. Add it to your Gemfile: gem "mini_magick"'
  end
end
|
|
184
|
+
|
|
185
|
+
# Look up a built-in image model by name.
# @raise [Error] if the name is not in SUPPORTED_IMAGE_MODELS
def resolve_model_info(model_name)
  SUPPORTED_IMAGE_MODELS.fetch(model_name) do
    raise Error, "Unknown image model: #{model_name}"
  end
end
|
|
191
|
+
|
|
192
|
+
# Point the instance at a model directory on disk. A built-in model's info
# is reused when the name matches; otherwise generic local info is created.
# @raise [ArgumentError] if the directory does not exist
def initialize_from_local(local_model_dir:, model_name:, model_file:)
  unless Dir.exist?(local_model_dir)
    raise ArgumentError, "Local model directory not found: #{local_model_dir}"
  end

  @model_dir = local_model_dir
  known_info = SUPPORTED_IMAGE_MODELS[model_name]
  @model_info = known_info || create_local_model_info(model_name: model_name, model_file: model_file)
end
|
|
201
|
+
|
|
202
|
+
# Build placeholder info for a local model that is not in the built-in
# registry. The dimension and image size are fixed defaults (512 and 224),
# not values read from the model.
def create_local_model_info(model_name:, model_file:)
  attributes = {
    model_name: model_name,
    dim: 512, # Default CLIP dimension
    description: 'Local image model',
    size_in_gb: 0,
    sources: {},
    model_file: model_file || 'model.onnx',
    image_size: 224
  }
  ImageModelInfo.new(**attributes)
end
|
|
213
|
+
|
|
214
|
+
# Delegate to ModelManagement to obtain the model directory for the
# already-resolved @model_info.
def retrieve_model(model_name, show_progress:)
  retrieval_options = { model_info: @model_info, show_progress: show_progress }
  ModelManagement.retrieve_model(model_name, **retrieval_options)
end
|
|
221
|
+
|
|
222
|
+
# Locate the ONNX file (explicit override first, then the model info's
# default) and open an inference session for it.
# @raise [Error] if the model file is missing
def setup_model
  model_path = File.join(@model_dir, @model_file_override || @model_info.model_file)
  raise Error, "Model file not found: #{model_path}" unless File.exist?(model_path)

  # Thread counts are only passed when the caller configured them.
  thread_options = {}
  if @threads
    thread_options[:inter_op_num_threads] = @threads
    thread_options[:intra_op_num_threads] = @threads
  end

  @session = OnnxRuntime::InferenceSession.new(model_path, **thread_options, providers: @providers)
end
|
|
237
|
+
|
|
238
|
+
# Preprocess a batch of images, run it through the ONNX session and return
# one L2-normalized vector per image.
def compute_embeddings(image_paths)
  # One preprocessed tensor per image; the array of tensors is the batch
  batch = image_paths.map { |image_path| preprocess_image(image_path) }

  # The batch is fed to the session's first declared input
  feed = { @session.inputs.first[:name] => batch }
  raw_vectors = @session.run(nil, feed).first

  raw_vectors.map { |vector| normalize_embedding(vector) }
end
|
|
253
|
+
|
|
254
|
+
# Turn one image path or URL into a normalized pixel tensor.
def preprocess_image(image_path)
  image = load_image(image_path)

  # The trailing "!" forces the exact square size, ignoring aspect ratio
  side = @model_info.image_size
  image.resize("#{side}x#{side}!")

  normalize_pixels(extract_pixels(image))
end
|
|
266
|
+
|
|
267
|
+
# Open an image from a local path or an http(s) URL.
#
# @param path [String, #to_s] Local file path (String or Pathname) or URL
# @return [MiniMagick::Image]
# @raise [Error] if a local path does not exist
def load_image(path)
  # Coerce so Pathname inputs work; a String is passed through unchanged.
  location = path.to_s
  remote = location.start_with?('http://', 'https://')
  raise Error, "Image file not found: #{path}" unless remote || File.exist?(location)

  MiniMagick::Image.open(location)
end
|
|
272
|
+
|
|
273
|
+
# Dump the image as raw 8-bit RGB via ImageMagick and return the bytes as
# a flat array of integers in 0..255.
def extract_pixels(image)
  # '-depth 8' requests 8 bits per channel; 'RGB:-' writes raw RGB to stdout
  raw_rgb = image.run_command('convert', image.path, '-depth', '8', 'RGB:-')
  raw_rgb.bytes
end
|
|
281
|
+
|
|
282
|
+
# Convert a flat, interleaved RGB byte array ([H, W, C] order) into a
# [C, H, W] nested array, scaling each value to 0..1 and normalizing it
# with the model's per-channel mean and std.
#
# @param pixels [Array<Integer>] size * size * 3 values in 0..255
# @return [Array<Array<Array<Float>>>] tensor indexed as [channel][row][column]
# @raise [Error] if the pixel count does not match the model's image size
def normalize_pixels(pixels)
  size = @model_info.image_size
  mean = @model_info.mean
  std = @model_info.std
  channels = 3

  # A wrong-sized buffer would otherwise crash on nil or leave nil cells
  # in the tensor.
  expected = size * size * channels
  unless pixels.length == expected
    raise Error, "Unexpected pixel data length: expected #{expected} values " \
                 "for a #{size}x#{size} RGB image, got #{pixels.length}"
  end

  Array.new(channels) do |c|
    Array.new(size) do |h|
      Array.new(size) do |w|
        pixel = pixels[(((h * size) + w) * channels) + c]
        # Normalize: (pixel/255 - mean) / std
        ((pixel / 255.0) - mean[c]) / std[c]
      end
    end
  end
end
|
|
303
|
+
|
|
304
|
+
# Scale a vector to unit L2 length. A nested array is flattened first, and
# an all-zero (or empty) vector is returned without scaling.
def normalize_embedding(embedding)
  vector = embedding
  vector = vector.flatten if vector.is_a?(Array) && vector.first.is_a?(Array)

  sum_of_squares = vector.sum { |value| value * value }
  length = Math.sqrt(sum_of_squares)

  if length.zero?
    vector
  else
    vector.map { |value| value / length }
  end
end
|
|
312
|
+
end
|
|
313
|
+
end
|