geminize 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/.standard.yml +3 -0
- data/.yardopts +14 -0
- data/CHANGELOG.md +24 -0
- data/CODE_OF_CONDUCT.md +132 -0
- data/CONTRIBUTING.md +109 -0
- data/LICENSE.txt +21 -0
- data/README.md +423 -0
- data/Rakefile +10 -0
- data/examples/README.md +75 -0
- data/examples/configuration.rb +58 -0
- data/examples/embeddings.rb +195 -0
- data/examples/multimodal.rb +126 -0
- data/examples/rails_chat/README.md +69 -0
- data/examples/rails_chat/app/controllers/chat_controller.rb +26 -0
- data/examples/rails_chat/app/views/chat/index.html.erb +112 -0
- data/examples/rails_chat/config/routes.rb +8 -0
- data/examples/rails_initializer.rb +46 -0
- data/examples/system_instructions.rb +101 -0
- data/lib/geminize/chat.rb +98 -0
- data/lib/geminize/client.rb +318 -0
- data/lib/geminize/configuration.rb +98 -0
- data/lib/geminize/conversation_repository.rb +161 -0
- data/lib/geminize/conversation_service.rb +126 -0
- data/lib/geminize/embeddings.rb +145 -0
- data/lib/geminize/error_mapper.rb +96 -0
- data/lib/geminize/error_parser.rb +120 -0
- data/lib/geminize/errors.rb +185 -0
- data/lib/geminize/middleware/error_handler.rb +72 -0
- data/lib/geminize/model_info.rb +91 -0
- data/lib/geminize/models/chat_request.rb +186 -0
- data/lib/geminize/models/chat_response.rb +118 -0
- data/lib/geminize/models/content_request.rb +530 -0
- data/lib/geminize/models/content_response.rb +99 -0
- data/lib/geminize/models/conversation.rb +156 -0
- data/lib/geminize/models/embedding_request.rb +222 -0
- data/lib/geminize/models/embedding_response.rb +1064 -0
- data/lib/geminize/models/memory.rb +88 -0
- data/lib/geminize/models/message.rb +140 -0
- data/lib/geminize/models/model.rb +171 -0
- data/lib/geminize/models/model_list.rb +124 -0
- data/lib/geminize/models/stream_response.rb +99 -0
- data/lib/geminize/rails/app/controllers/concerns/geminize/controller.rb +105 -0
- data/lib/geminize/rails/app/helpers/geminize_helper.rb +125 -0
- data/lib/geminize/rails/controller_additions.rb +41 -0
- data/lib/geminize/rails/engine.rb +29 -0
- data/lib/geminize/rails/helper_additions.rb +37 -0
- data/lib/geminize/rails.rb +50 -0
- data/lib/geminize/railtie.rb +33 -0
- data/lib/geminize/request_builder.rb +57 -0
- data/lib/geminize/text_generation.rb +285 -0
- data/lib/geminize/validators.rb +150 -0
- data/lib/geminize/vector_utils.rb +164 -0
- data/lib/geminize/version.rb +5 -0
- data/lib/geminize.rb +527 -0
- data/lib/generators/geminize/install_generator.rb +22 -0
- data/lib/generators/geminize/templates/README +31 -0
- data/lib/generators/geminize/templates/initializer.rb +38 -0
- data/sig/geminize.rbs +4 -0
- metadata +218 -0
@@ -0,0 +1,1064 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Geminize
|
4
|
+
module Models
|
5
|
+
# Represents a response from the Gemini API for an embedding request
|
6
|
+
class EmbeddingResponse
|
7
|
+
# @return [Hash] The raw response data from the API
|
8
|
+
attr_reader :data
|
9
|
+
|
10
|
+
# @return [Hash, nil] Token counts for the request and response
|
11
|
+
attr_reader :usage
|
12
|
+
|
13
|
+
# Initialize a new embedding response
|
14
|
+
# @param data [Hash] The raw API response data
|
15
|
+
# @raise [Geminize::ValidationError] If the data doesn't contain valid embeddings
|
16
|
+
def initialize(data)
  # Store the raw payload first, then run the two setup steps
  # (validate! and parse_response are defined outside this view).
  @data = data

  validate!
  parse_response
end
|
21
|
+
|
22
|
+
# Get the embedding values as a flat array
|
23
|
+
# @return [Array<Float>, nil] The embedding values, or nil when the response is a batch
|
24
|
+
def values
  # Only a single-embedding payload carries a top-level "embedding" hash;
  # anything else (including a batch) yields nil.
  single? ? @data["embedding"]["values"] : nil
end
|
28
|
+
|
29
|
+
# Check if the response is a batch (multiple embeddings)
|
30
|
+
# @return [Boolean] True if the response contains multiple embeddings
|
31
|
+
def batch?
  # A batch payload is recognised by an "embeddings" key holding an Array.
  return false unless @data.has_key?("embeddings")

  @data["embeddings"].is_a?(Array)
end
|
35
|
+
|
36
|
+
# Check if the response is a single embedding (not a batch)
|
37
|
+
# @return [Boolean] True if the response contains a single embedding
|
38
|
+
def single?
  # A single payload is an "embedding" key whose value is a Hash with a "values" entry.
  return false unless @data.has_key?("embedding")

  candidate = @data["embedding"]
  candidate.is_a?(Hash) && candidate.has_key?("values")
end
|
42
|
+
|
43
|
+
# Get all embeddings as an array of arrays
|
44
|
+
# @return [Array<Array<Float>>] Array of embedding vectors
|
45
|
+
def embeddings
  # Single responses are wrapped so callers always get an array of vectors.
  return [values] if single?
  return [] unless batch?

  @data["embeddings"].map { |entry| entry["values"] }
end
|
54
|
+
|
55
|
+
# Get the size of each embedding (vector dimension)
|
56
|
+
# @return [Integer] The number of dimensions in each embedding
|
57
|
+
def embedding_size
  return values.size if single?
  return 0 unless batch?

  # For a batch, the first vector's length stands in for all of them.
  batch = @data["embeddings"]
  batch.empty? ? 0 : batch.first["values"].size
end
|
66
|
+
|
67
|
+
# Get the number of embeddings in the batch
|
68
|
+
# @return [Integer] The number of embeddings (1 for single embedding, N for batch)
|
69
|
+
def batch_size
  # 1 for a single response, N for a batch, 0 when the payload is neither.
  return 1 if single?

  batch? ? @data["embeddings"].size : 0
end
|
78
|
+
|
79
|
+
# Get a specific embedding from the batch by index
|
80
|
+
# @param index [Integer] The index of the embedding to retrieve
|
81
|
+
# @return [Array<Float>] The embedding values at the specified index
|
82
|
+
# @raise [IndexError] If the index is out of bounds
|
83
|
+
def embedding_at(index)
  count = batch_size
  if index < 0 || index >= count
    raise IndexError, "Index #{index} out of bounds for batch size #{count}"
  end

  # A single response only has position 0; a batch indexes into its list.
  return values if single? && index == 0

  @data["embeddings"][index]["values"] if batch?
end
|
94
|
+
|
95
|
+
# Alias for embedding_at(0)
|
96
|
+
# @return [Array<Float>] The first embedding
|
97
|
+
def embedding
  # Position 0 is the only vector of a single response and the first of a batch.
  first_position = 0
  embedding_at(first_position)
end
|
100
|
+
|
101
|
+
# Calculate the cosine similarity between two embedding vectors
|
102
|
+
# @param vec1 [Array<Float>] First vector
|
103
|
+
# @param vec2 [Array<Float>] Second vector
|
104
|
+
# @return [Float] Cosine similarity (-1 to 1)
|
105
|
+
# @raise [Geminize::ValidationError] If vectors have different dimensions
|
106
|
+
def self.cosine_similarity(vec1, vec2)
  # Class-level convenience wrapper; the computation lives in VectorUtils.
  Geminize::VectorUtils.cosine_similarity(vec1, vec2)
end
|
109
|
+
|
110
|
+
# Calculate the cosine similarity between two embedding indexes in this response
|
111
|
+
# @param index1 [Integer] First embedding index
|
112
|
+
# @param index2 [Integer] Second embedding index
|
113
|
+
# @return [Float] Cosine similarity (-1 to 1)
|
114
|
+
# @raise [IndexError] If either index is out of bounds
# @raise [Geminize::ValidationError] If the embedding at either index has no values
|
115
|
+
def similarity(index1, index2)
  # Both lookups run before any nil check, so an out-of-range index
  # (IndexError from embedding_at) is reported first.
  lookups = [index1, index2].map { |idx| [idx, embedding_at(idx)] }

  missing = lookups.find { |_, vector| vector.nil? }
  if missing
    raise Geminize::ValidationError.new("Invalid embedding index: #{missing.first}", "INVALID_ARGUMENT")
  end

  VectorUtils.cosine_similarity(lookups[0].last, lookups[1].last)
end
|
129
|
+
|
130
|
+
# Calculate the cosine similarity between an embedding in this response and another vector
|
131
|
+
# @param index [Integer] Embedding index in this response
|
132
|
+
# @param other_vector [Array<Float>] External vector to compare with
|
133
|
+
# @return [Float] Cosine similarity (-1 to 1)
|
134
|
+
# @raise [IndexError] If the index is out of bounds
# @raise [Geminize::ValidationError] If the embedding at the index has no values
|
135
|
+
def similarity_with_vector(index, other_vector)
  own_vector = embedding_at(index)
  return VectorUtils.cosine_similarity(own_vector, other_vector) unless own_vector.nil?

  # Reached only when the stored entry has no values.
  raise Geminize::ValidationError.new("Invalid embedding index: #{index}", "INVALID_ARGUMENT")
end
|
144
|
+
|
145
|
+
# Compute similarity matrix for all embeddings in this response
|
146
|
+
# @param metric [Symbol] Distance metric to use (:cosine or :euclidean)
|
147
|
+
# @return [Array<Array<Float>>] Matrix of similarity scores
|
148
|
+
def similarity_matrix(metric = :cosine)
  # Resolve the metric once, up front. Previously the check lived inside the
  # pair loop, so an unknown metric was silently accepted whenever there were
  # fewer than two embeddings.
  scorer = case metric
  when :cosine
    ->(a, b) { VectorUtils.cosine_similarity(a, b) }
  when :euclidean
    # Convert distance to a similarity (higher is more similar).
    ->(a, b) { 1.0 / (1.0 + VectorUtils.euclidean_distance(a, b)) }
  else
    raise Geminize::ValidationError.new(
      "Unknown metric: #{metric}. Supported metrics: :cosine, :euclidean",
      "INVALID_ARGUMENT"
    )
  end

  vectors = embeddings
  size = vectors.length
  matrix = Array.new(size) { Array.new(size, 0.0) }

  vectors.each_with_index do |vec1, i|
    # Self-similarity is fixed at 1.0 for both metrics.
    matrix[i][i] = 1.0

    # Score each unordered pair once and mirror it; the matrix is symmetric.
    ((i + 1)...size).each do |j|
      score = scorer.call(vec1, vectors[j])
      matrix[i][j] = score
      matrix[j][i] = score
    end
  end

  matrix
end
|
181
|
+
|
182
|
+
# Find the most similar embeddings to a given index
|
183
|
+
# @param index [Integer] Index of the embedding to compare against
|
184
|
+
# @param top_k [Integer, nil] Number of similar embeddings to return
|
185
|
+
# @param metric [Symbol] Distance metric to use (:cosine or :euclidean)
|
186
|
+
# @return [Array<Hash>] Array of {index:, similarity:} hashes sorted by similarity
|
187
|
+
# @raise [IndexError] If the index is out of bounds
# @raise [Geminize::ValidationError] If the embedding at the index has no values
|
188
|
+
def most_similar(index, top_k = nil, metric = :cosine)
  target = embedding_at(index)
  if target.nil?
    raise Geminize::ValidationError.new("Invalid embedding index: #{index}", "INVALID_ARGUMENT")
  end

  # Candidates are every embedding except the target; remember where each came from.
  all_vectors = embeddings
  other_vectors = all_vectors.each_with_index.reject { |_, i| i == index }.map(&:first).compact
  other_indexes = (0...all_vectors.length).reject { |i| i == index }

  # Rank all candidates (no limit passed down), then translate the helper's
  # candidate positions back to positions in this response.
  ranked = VectorUtils.most_similar(target, other_vectors, nil, metric)
  ranked.each { |hit| hit[:index] = other_indexes[hit[:index]] }

  top_k ? ranked.take(top_k) : ranked
end
|
209
|
+
|
210
|
+
# Normalize embeddings to unit length
|
211
|
+
# @return [Array<Array<Float>>] Normalized embeddings
|
212
|
+
def normalized_embeddings
  # Each vector is normalised independently by the shared helper.
  embeddings.map do |vector|
    VectorUtils.normalize(vector)
  end
end
|
215
|
+
|
216
|
+
# Average the embeddings in this response
|
217
|
+
# @return [Array<Float>] Average embedding vector
|
218
|
+
# @raise [Geminize::ValidationError] If there are no embeddings
|
219
|
+
def average_embedding
  vectors = embeddings
  return VectorUtils.average_vectors(vectors) unless vectors.empty?

  # Nothing to average.
  raise Geminize::ValidationError.new("No embeddings found to average", "INVALID_ARGUMENT")
end
|
227
|
+
|
228
|
+
# Calculate Euclidean distance between two embeddings
|
229
|
+
# @param index1 [Integer] First embedding index
|
230
|
+
# @param index2 [Integer] Second embedding index
|
231
|
+
# @return [Float] Euclidean distance
|
232
|
+
# @raise [IndexError] If either index is out of bounds
# @raise [Geminize::ValidationError] If the embedding at either index has no values
|
233
|
+
def euclidean_distance(index1, index2)
  # Both lookups run before any nil check, so an out-of-range index
  # (IndexError from embedding_at) is reported first.
  lookups = [index1, index2].map { |idx| [idx, embedding_at(idx)] }

  missing = lookups.find { |_, vector| vector.nil? }
  if missing
    raise Geminize::ValidationError.new("Invalid embedding index: #{missing.first}", "INVALID_ARGUMENT")
  end

  VectorUtils.euclidean_distance(lookups[0].last, lookups[1].last)
end
|
247
|
+
|
248
|
+
# Export embeddings to a JSON string
|
249
|
+
# @param pretty [Boolean] Whether to format the JSON with indentation
|
250
|
+
# @return [String] JSON representation of the embeddings
|
251
|
+
def to_json(pretty = false)
  vectors = embeddings
  payload = {
    embeddings: vectors,
    dimensions: dimensions,
    count: vectors.length
  }

  # Any truthy argument selects the indented form.
  pretty ? JSON.pretty_generate(payload) : JSON.generate(payload)
end
|
264
|
+
|
265
|
+
# Export embeddings to a CSV string
|
266
|
+
# @param include_header [Boolean] Whether to include a header row with dimension indices
|
267
|
+
# @return [String] CSV representation of the embeddings
|
268
|
+
def to_csv(include_header = true)
  rows = embeddings
  return "" if rows.empty?

  width = dimensions || 0
  lines = rows.map { |vector| vector.join(",") }

  # Optional header row: dim_0, dim_1, ... one label per dimension.
  lines.unshift(Array.new(width) { |i| "dim_#{i}" }.join(",")) if include_header

  lines.join("\n")
end
|
287
|
+
|
288
|
+
# Transform embeddings to a hash with specified keys
|
289
|
+
# @param keys [Array<String>, nil] Keys to associate with each vector (must match number of embeddings)
|
290
|
+
# @return [Hash] Hash mapping keys to embedding vectors
|
291
|
+
# @raise [Geminize::ValidationError] If keys count doesn't match embeddings count
|
292
|
+
def to_hash_with_keys(keys)
  vectors = embeddings

  # Without explicit keys, fall back to stringified positions: "0", "1", ...
  return vectors.each_index.map(&:to_s).zip(vectors).to_h if keys.nil?

  if keys.length != vectors.length
    raise Geminize::ValidationError.new(
      "Number of keys (#{keys.length}) doesn't match number of embeddings (#{vectors.length})",
      "INVALID_ARGUMENT"
    )
  end

  keys.zip(vectors).to_h
end
|
310
|
+
|
311
|
+
# Prepare data for visualization with dimensionality reduction
|
312
|
+
# @param method [Symbol] Dimensionality reduction method (:pca or :tsne)
|
313
|
+
# @param dimensions [Integer] Number of dimensions to reduce to (1-3)
|
314
|
+
# @return [Array<Hash>] Array of points with reduced coordinates
|
315
|
+
# @note This method provides the data structure for visualization but requires external
|
316
|
+
# libraries like 'iruby' and 'numo' to perform actual dimensionality reduction
|
317
|
+
# Users should transform this data according to their visualization framework
|
318
|
+
def prepare_visualization_data(method = :pca, dimensions = 2)
  if ![:pca, :tsne].include?(method)
    raise Geminize::ValidationError.new(
      "Unknown dimensionality reduction method: #{method}. Supported methods: :pca, :tsne",
      "INVALID_ARGUMENT"
    )
  end

  if !(1..3).cover?(dimensions)
    raise Geminize::ValidationError.new(
      "Dimensions must be between 1 and 3, got: #{dimensions}",
      "INVALID_ARGUMENT"
    )
  end

  point_count = embeddings.length
  return [] if point_count.zero?

  # No reduction happens here: every point gets zeroed placeholder
  # coordinates plus its original vector, for the caller to project.
  Array.new(point_count) do |i|
    {
      index: i,
      coordinates: Array.new(dimensions, 0.0),
      original_vector: embedding_at(i)
    }
  end
end
|
350
|
+
|
351
|
+
# Get the dimensionality of the embeddings
|
352
|
+
# @return [Integer, nil] The number of dimensions or nil if no embeddings
|
353
|
+
def dimensions
  # Read the first vector straight from the embeddings list. The previous
  # route through embedding_at(0) raised IndexError when there were no
  # embeddings, so the documented nil return could never happen.
  embeddings.first&.length
end
|
357
|
+
|
358
|
+
# Get the total token count
|
359
|
+
# @return [Integer, nil] Total token count or nil if not available
|
360
|
+
def total_tokens
  return nil unless @usage

  # Return the recorded total as-is. The previous version added
  # promptTokenCount on top of it; because the file loader stores this
  # method's result back into "totalTokenCount", that sum grew on every
  # save/load round trip.
  # NOTE(review): falling back to promptTokenCount assumes an embedding
  # request has only prompt tokens — confirm against the API reference.
  @usage["totalTokenCount"] || @usage["promptTokenCount"] || 0
end
|
365
|
+
|
366
|
+
# Get the prompt token count
|
367
|
+
# @return [Integer, nil] Prompt token count or nil if not available
|
368
|
+
def prompt_tokens
  # nil when no usage data was recorded, otherwise the stored prompt count.
  @usage ? @usage["promptTokenCount"] : nil
end
|
373
|
+
|
374
|
+
# Create an EmbeddingResponse object from a raw API response
|
375
|
+
# @param response_data [Hash] The raw API response
|
376
|
+
# @return [EmbeddingResponse] A new EmbeddingResponse object
|
377
|
+
def self.from_hash(response_data)
  # Named constructor: hands the hash straight to the class constructor.
  new(response_data)
end
|
380
|
+
|
381
|
+
# Export embeddings to a Numpy-compatible format
|
382
|
+
# @return [Hash] A hash with ndarray compatible data structure
|
383
|
+
# @note This method provides a structure that can be easily converted to
|
384
|
+
# a numpy array in Python or used with Ruby libraries that support
|
385
|
+
# numpy-compatible formats
|
386
|
+
def to_numpy_format
  # Shape is [rows, columns]; the column count falls back to 0 when unknown.
  row_count = batch_size
  column_count = dimensions || 0

  {data: embeddings, shape: [row_count, column_count], dtype: "float32"}
end
|
393
|
+
|
394
|
+
# Extract the first K dimensions of each embedding (plain truncation; no significance ranking is performed)
|
395
|
+
# @param k [Integer] Number of dimensions to extract
|
396
|
+
# @return [Array<Array<Float>>] Embeddings with only the top K dimensions
|
397
|
+
# @raise [Geminize::ValidationError] If K is greater than available dimensions
|
398
|
+
def top_dimensions(k)
  available = dimensions

  if available.nil? || available == 0
    raise Geminize::ValidationError.new("No embeddings found", "INVALID_ARGUMENT")
  end

  unless k <= available
    raise Geminize::ValidationError.new(
      "Cannot extract #{k} dimensions from embeddings with only #{available} dimensions",
      "INVALID_ARGUMENT"
    )
  end

  # Plain truncation: keep the leading k components of every vector.
  embeddings.map { |vector| vector.take(k) }
end
|
417
|
+
|
418
|
+
# Get metadata about the embeddings
|
419
|
+
# @return [Hash] Metadata about the embeddings including counts and token usage
|
420
|
+
def metadata
  # Assembled in three groups; key order is kept stable because this hash
  # is serialised when saving to JSON.
  shape_info = {count: batch_size, dimensions: dimensions}
  token_info = {total_tokens: total_tokens, prompt_tokens: prompt_tokens}
  kind_info = {is_batch: batch?, is_single: single?}

  shape_info.merge(token_info).merge(kind_info)
end
|
430
|
+
|
431
|
+
# Raw response data from the API
|
432
|
+
# @return [Hash] The complete raw API response
|
433
|
+
def raw_response
  # Same object that was handed to the constructor; not a copy.
  @data
end
|
436
|
+
|
437
|
+
# Iterates through each embedding with its index
|
438
|
+
# @yield [embedding, index] Block to execute for each embedding
|
439
|
+
# @yieldparam embedding [Array<Float>] The embedding vector
|
440
|
+
# @yieldparam index [Integer] The index of the embedding
|
441
|
+
# @return [Enumerator, self] Returns an enumerator if no block given, or self if block given
|
442
|
+
def each_embedding
  return to_enum(:each_embedding) unless block_given?

  # Yield every vector together with its position, then return self.
  embeddings.each.with_index { |vector, position| yield vector, position }

  self
end
|
452
|
+
|
453
|
+
# Converts embeddings to a simple array
|
454
|
+
# @return [Array<Array<Float>>] Array of embedding vectors
|
455
|
+
def to_a
  # Array conversion is simply the list of embedding vectors.
  embeddings
end
|
458
|
+
|
459
|
+
# Associates labels/texts with embeddings
|
460
|
+
# @param labels [Array<String>] Labels to associate with embeddings
|
461
|
+
# @return [Hash] Hash mapping labels to embeddings
|
462
|
+
# @raise [Geminize::ValidationError] If the number of labels doesn't match the number of embeddings
|
463
|
+
def with_labels(labels)
  if !labels.is_a?(Array)
    raise Geminize::ValidationError.new("Labels must be an array", "INVALID_ARGUMENT")
  end

  vectors = embeddings
  if labels.length != vectors.length
    raise Geminize::ValidationError.new(
      "Number of labels (#{labels.length}) doesn't match number of embeddings (#{vectors.length})",
      "INVALID_ARGUMENT"
    )
  end

  # Pair each label with the vector at the same position.
  labels.zip(vectors).to_h
end
|
479
|
+
|
480
|
+
# Filter embeddings based on a condition
|
481
|
+
# @yield [embedding, index] Block that returns true if the embedding should be included
|
482
|
+
# @yieldparam embedding [Array<Float>] The embedding vector
|
483
|
+
# @yieldparam index [Integer] The index of the embedding
|
484
|
+
# @return [Array<Array<Float>>] Filtered embeddings
|
485
|
+
# @note This method doesn't modify the original response object
|
486
|
+
def filter
  return to_enum(:filter) unless block_given?

  # Collect the vectors the caller's block approves; the response itself is untouched.
  kept = []
  each_embedding { |vector, position| kept.push(vector) if yield(vector, position) }
  kept
end
|
495
|
+
|
496
|
+
# Get a subset of embeddings by indices
|
497
|
+
# @param start [Integer] Start index (inclusive)
|
498
|
+
# @param finish [Integer, nil] End index (inclusive), or nil to select until the end
|
499
|
+
# @return [Array<Array<Float>>] Subset of embeddings
|
500
|
+
# @raise [IndexError] If the range is invalid
|
501
|
+
def slice(start, finish = nil)
  vectors = embeddings
  total = vectors.length

  # Negative positions count back from the end; a missing finish means
  # "through the last embedding".
  start += total if start < 0
  finish = if finish.nil?
    total - 1
  elsif finish < 0
    total + finish
  else
    finish
  end

  if start < 0 || start >= total
    raise IndexError, "Start index #{start} out of bounds for embeddings size #{total}"
  end

  if finish < start || finish >= total
    raise IndexError, "End index #{finish} out of bounds for embeddings size #{total}"
  end

  vectors[start..finish]
end
|
520
|
+
|
521
|
+
# Combine with another EmbeddingResponse
|
522
|
+
# @param other [Geminize::Models::EmbeddingResponse] Another embedding response to combine with
|
523
|
+
# @return [Geminize::Models::EmbeddingResponse] A new combined response
|
524
|
+
# @raise [Geminize::ValidationError] If the embeddings have different dimensions
|
525
|
+
def combine(other)
  if !other.is_a?(Geminize::Models::EmbeddingResponse)
    raise Geminize::ValidationError.new(
      "Can only combine with another EmbeddingResponse",
      "INVALID_ARGUMENT"
    )
  end

  # Vectors of different widths cannot live in one response.
  unless dimensions == other.dimensions
    raise Geminize::ValidationError.new(
      "Cannot combine embeddings with different dimensions (#{dimensions} vs #{other.dimensions})",
      "INVALID_ARGUMENT"
    )
  end

  # This response's vectors first, then the other's, in batch payload shape.
  entries = (embeddings + other.embeddings).map { |vector| {"values" => vector} }

  # Add up whatever usage figures are present; a missing count is treated as 0.
  usage_sources = [@usage, other.usage].select { |usage| usage }
  prompt_sum = usage_sources.inject(0) { |sum, usage| sum + (usage["promptTokenCount"] || 0) }
  total_sum = usage_sources.inject(0) { |sum, usage| sum + (usage["totalTokenCount"] || 0) }

  combined_hash = {
    "embeddings" => entries,
    "usageMetadata" => {
      "promptTokenCount" => prompt_sum,
      "totalTokenCount" => total_sum
    }
  }

  self.class.from_hash(combined_hash)
end
|
575
|
+
|
576
|
+
# Save embeddings to a file
|
577
|
+
# @param path [String] Path to save the file
|
578
|
+
# @param format [Symbol] Format to save in (:json, :csv, :binary)
|
579
|
+
# @param options [Hash] Additional options for saving
|
580
|
+
# @option options [Boolean] :pretty Format JSON with indentation (for :json format)
|
581
|
+
# @option options [Boolean] :include_header Include header with dimension indices (for :csv format)
|
582
|
+
# @option options [Boolean] :include_metadata Include metadata in the saved file
|
583
|
+
# @return [Boolean] True if successful
|
584
|
+
# @raise [Geminize::ValidationError] If the format is invalid or file operations fail
|
585
|
+
def save(path, format = :json, options = {})
  # Caller-supplied options override these defaults.
  options = {
    pretty: false,
    include_header: true,
    include_metadata: true
  }.merge(options)

  # Render the payload BEFORE touching the filesystem. Previously the file
  # was opened for writing first, so an unsupported format left behind an
  # empty (or truncated) file.
  content = case format
  when :json
    payload = {"embeddings" => embeddings}
    payload["metadata"] = metadata if options[:include_metadata]

    options[:pretty] ? JSON.pretty_generate(payload) : JSON.generate(payload)
  when :csv
    to_csv(options[:include_header])
  when :binary
    raise Geminize::ValidationError.new(
      "Binary format not yet implemented",
      "INVALID_ARGUMENT"
    )
  else
    raise Geminize::ValidationError.new(
      "Unknown format: #{format}. Supported formats: :json, :csv",
      "INVALID_ARGUMENT"
    )
  end

  File.open(path, "w") { |file| file.write(content) }

  true
rescue Geminize::ValidationError
  # Argument errors raised above keep their own message and INVALID_ARGUMENT
  # code instead of being re-labelled as I/O failures.
  raise
rescue => e
  # Any other failure (filesystem, serialisation) is reported as before.
  raise Geminize::ValidationError.new(
    "Failed to save embeddings: #{e.message}",
    "IO_ERROR"
  )
end
|
626
|
+
|
627
|
+
# Load embeddings from a file
|
628
|
+
# @param path [String] Path to the file
|
629
|
+
# @param format [Symbol, nil] Format of the file (:json, :csv, :binary)
|
630
|
+
# If nil, format will be inferred from file extension
|
631
|
+
# @return [Geminize::Models::EmbeddingResponse] A new embedding response
|
632
|
+
# @raise [Geminize::ValidationError] If the file format is invalid or file operations fail
|
633
|
+
def self.load(path, format = nil)
  # Infer the format from the file extension when the caller gave none.
  if format.nil?
    ext = File.extname(path).downcase.delete(".")
    format = case ext
    when "json" then :json
    when "csv" then :csv
    when "bin" then :binary
    else
      raise Geminize::ValidationError.new(
        "Could not infer format from file extension: #{ext}",
        "INVALID_ARGUMENT"
      )
    end
  end

  content = File.read(path)

  case format
  when :json
    data = JSON.parse(content)

    # Without an "embeddings" key, assume the file already is an API response.
    return from_hash(data) unless data["embeddings"]

    # Entries are either bare vectors (the shape written by #save) or API-style
    # {"values" => [...]} hashes; only bare vectors need wrapping. Wrapping
    # unconditionally used to double-nest API-shaped batch files.
    response_data = {
      "embeddings" => data["embeddings"].map { |vec| vec.is_a?(Hash) ? vec : {"values" => vec} }
    }

    # Restore usage metadata when the file carries it.
    meta = data["metadata"]
    if meta && meta["total_tokens"]
      response_data["usageMetadata"] = {
        "promptTokenCount" => meta["prompt_tokens"] || 0,
        "totalTokenCount" => meta["total_tokens"] || 0
      }
    end

    from_hash(response_data)
  when :csv
    # Blank lines (including a trailing newline) carry no vector.
    rows = content.each_line.map(&:strip).reject(&:empty?)
    if rows.empty?
      raise Geminize::ValidationError.new(
        "No embeddings found in CSV file",
        "INVALID_ARGUMENT"
      )
    end

    # The first row is a header only if some cell is not a plain number.
    # A bare "contains a letter" test misfires on exponent notation such as
    # 1.0e-05, which Float#to_s produces for small values.
    numeric_cell = /\A[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?\z/
    has_header = rows.first.split(",").any? { |cell| !numeric_cell.match?(cell.strip) }
    rows.shift if has_header

    vectors = rows.map { |row| row.split(",").map(&:to_f) }

    from_hash({"embeddings" => vectors.map { |vec| {"values" => vec} }})
  when :binary
    raise Geminize::ValidationError.new(
      "Binary format not yet implemented",
      "INVALID_ARGUMENT"
    )
  else
    raise Geminize::ValidationError.new(
      "Unknown format: #{format}. Supported formats: :json, :csv",
      "INVALID_ARGUMENT"
    )
  end
rescue Geminize::ValidationError
  # Errors raised deliberately above keep their own message and code
  # instead of being re-labelled as I/O failures.
  raise
rescue JSON::ParserError => e
  raise Geminize::ValidationError.new(
    "Failed to parse JSON: #{e.message}",
    "INVALID_ARGUMENT"
  )
rescue => e
  raise Geminize::ValidationError.new(
    "Failed to load embeddings: #{e.message}",
    "IO_ERROR"
  )
end
|
716
|
+
|
717
|
+
# Perform simple k-means clustering of the embeddings.
#
# Vectors are normalized first, centroids are seeded with
# kmeans_plus_plus_init, and the assign/update loop then runs until no
# assignment changes or max_iterations is reached.
#
# @param k [Integer] Number of clusters
# @param max_iterations [Integer] Maximum number of iterations for clustering
# @param metric [Symbol] Distance metric to use (:cosine or :euclidean)
# @return [Hash] Hash with :clusters (array of embedding indices per cluster),
#   :centroids (cluster centers), :iterations (iterations performed) and
#   :metric (the metric that was used)
# @raise [Geminize::ValidationError] If there are no embeddings, k is out of
#   range, or the metric is not supported
# @note This is a basic implementation of k-means clustering for demonstration purposes
def cluster(k, max_iterations = 100, metric = :cosine)
  vecs = embeddings

  if vecs.empty?
    raise Geminize::ValidationError.new(
      "Cannot cluster empty embeddings",
      "INVALID_ARGUMENT"
    )
  end

  if k <= 0 || k > vecs.length
    raise Geminize::ValidationError.new(
      "Number of clusters must be between 1 and #{vecs.length}, got: #{k}",
      "INVALID_ARGUMENT"
    )
  end

  # Reject unsupported metrics before any work is done. Seeding already
  # depends on the metric, so validating later would surface an unrelated
  # low-level error instead of a ValidationError.
  unless [:cosine, :euclidean].include?(metric)
    raise Geminize::ValidationError.new(
      "Unknown metric: #{metric}. Supported metrics: :cosine, :euclidean",
      "INVALID_ARGUMENT"
    )
  end

  # Normalize vectors for better clustering (especially important for cosine similarity)
  normalized_vecs = vecs.map { |v| VectorUtils.normalize(v) }

  # Initialize centroids using k-means++ algorithm
  centroids = kmeans_plus_plus_init(normalized_vecs, k, metric)

  # -1 marks "not assigned yet"
  cluster_assignments = Array.new(normalized_vecs.length, -1)

  iterations = 0
  changes = true

  while changes && iterations < max_iterations
    changes = false

    # Assignment step: attach every point to its most similar centroid
    normalized_vecs.each_with_index do |vec, idx|
      best_similarity = -Float::INFINITY
      best_cluster = -1

      centroids.each_with_index do |centroid, cluster_idx|
        # Both metrics are expressed as a similarity (higher is better)
        similarity =
          if metric == :cosine
            VectorUtils.cosine_similarity(vec, centroid)
          else
            1.0 / (1.0 + VectorUtils.euclidean_distance(vec, centroid))
          end

        if similarity > best_similarity
          best_similarity = similarity
          best_cluster = cluster_idx
        end
      end

      if cluster_assignments[idx] != best_cluster
        cluster_assignments[idx] = best_cluster
        changes = true
      end
    end

    # Update step: gather the members of each cluster
    members = Array.new(k) { [] }
    normalized_vecs.each_with_index do |vec, idx|
      cluster_idx = cluster_assignments[idx]
      members[cluster_idx] << vec if cluster_idx >= 0
    end

    members.each_with_index do |cluster_points, idx|
      if cluster_points.empty?
        # If a cluster is empty, reinitialize with a point farthest from other centroids
        farthest_idx = find_farthest_point(normalized_vecs, centroids, cluster_assignments)
        centroids[idx] = normalized_vecs[farthest_idx].dup
      else
        # Otherwise take the average and normalize
        avg = VectorUtils.average_vectors(cluster_points)
        centroids[idx] = VectorUtils.normalize(avg)
      end
    end

    iterations += 1
  end

  # Organize results by cluster
  clusters = Array.new(k) { [] }
  cluster_assignments.each_with_index do |cluster_idx, idx|
    clusters[cluster_idx] << idx if cluster_idx >= 0
  end

  {
    clusters: clusters,
    centroids: centroids,
    iterations: iterations,
    metric: metric
  }
end
|
828
|
+
|
829
|
+
# Resize embeddings to a different dimension.
#
# Both supported methods behave identically: when new_dim is larger than
# the current dimension every vector is extended with pad_value, otherwise
# every vector is cut down to its first new_dim entries. The stored
# embeddings are not modified; new arrays are returned.
#
# @param new_dim [Integer] New dimension size
# @param method [Symbol] Method to use for resizing (:truncate, :pad)
# @param pad_value [Float] Value appended when vectors have to grow
# @return [Array<Array<Float>>] Resized embeddings
# @raise [Geminize::ValidationError] If there are no embeddings, new_dim is
#   not positive, or the method is unknown
def resize(new_dim, method = :truncate, pad_value = 0.0)
  vecs = embeddings

  if vecs.empty?
    raise Geminize::ValidationError.new(
      "Cannot resize empty embeddings",
      "INVALID_ARGUMENT"
    )
  end

  if new_dim <= 0
    raise Geminize::ValidationError.new(
      "New dimension must be positive, got: #{new_dim}",
      "INVALID_ARGUMENT"
    )
  end

  unless [:truncate, :pad].include?(method)
    raise Geminize::ValidationError.new(
      "Unknown resize method: #{method}. Supported methods: :truncate, :pad",
      "INVALID_ARGUMENT"
    )
  end

  current_dim = dimensions

  # :truncate and :pad share one code path: grow with pad_value or cut down.
  if new_dim > current_dim
    padding = Array.new(new_dim - current_dim, pad_value)
    vecs.map { |vec| vec + padding }
  else
    vecs.map { |vec| vec.take(new_dim) }
  end
end
|
884
|
+
|
885
|
+
# Run every embedding through the given block and collect the results.
#
# Without a block an Enumerator over this method is returned instead.
#
# @yield [embedding, index] Block that transforms a single embedding
# @yieldparam embedding [Array<Float>] The embedding vector
# @yieldparam index [Integer] The index of the embedding
# @yieldreturn [Array<Float>] The transformed embedding
# @return [Array<Array<Float>>, Enumerator] Transformed embeddings, or an
#   Enumerator when no block is given
# @raise [Geminize::ValidationError] If the block returns something other than an Array
def map_embeddings
  return to_enum(:map_embeddings) unless block_given?

  embeddings.each_with_index.map do |vector, position|
    output = yield(vector, position)

    # Only arrays are accepted as transformed embeddings
    next output if output.is_a?(Array)

    raise Geminize::ValidationError.new(
      "Transformation must return an array, got: #{output.class}",
      "INVALID_ARGUMENT"
    )
  end
end
|
913
|
+
|
914
|
+
private
|
915
|
+
|
916
|
+
# Check that the response data carries usable embeddings.
#
# A single response must hold an array under "embedding" => "values". A
# batch response must hold a non-empty "embeddings" list whose entries are
# hashes with an array under "values", all of the same length.
#
# @return [nil]
# @raise [Geminize::ValidationError] If the data doesn't contain valid embeddings
def validate!
  # Neither a single nor a batch payload: nothing to work with
  unless single? || batch?
    raise Geminize::ValidationError.new("No embedding data found", "INVALID_RESPONSE")
  end

  if single?
    single_values = @data["embedding"]["values"]
    unless single_values.is_a?(Array)
      raise Geminize::ValidationError.new("Embedding values must be an array", "INVALID_RESPONSE")
    end
  end

  return unless batch?

  entries = @data["embeddings"]
  raise Geminize::ValidationError.new("Empty embeddings array", "INVALID_RESPONSE") if entries.empty?

  # Every entry must be a hash exposing an array under "values"
  entries.each_with_index do |entry, position|
    if !entry.is_a?(Hash) || !entry.key?("values")
      raise Geminize::ValidationError.new(
        "Embedding at index #{position} must have 'values' key",
        "INVALID_RESPONSE"
      )
    end

    next if entry["values"].is_a?(Array)

    raise Geminize::ValidationError.new(
      "Embedding values at index #{position} must be an array",
      "INVALID_RESPONSE"
    )
  end

  # All vectors of a batch have to share one length
  lengths = entries.map { |entry| entry["values"].size }
  raise Geminize::ValidationError.new("Inconsistent embedding sizes", "INVALID_RESPONSE") if lengths.uniq.size != 1
end
|
959
|
+
|
960
|
+
# Pull the derived information out of the raw response data.
# At present this only delegates to parse_usage.
def parse_response
  parse_usage
end
|
964
|
+
|
965
|
+
# Store the "usageMetadata" entry of the response data in @usage.
# @usage is left untouched when the entry is missing or falsy.
def parse_usage
  usage_metadata = @data["usageMetadata"]
  @usage = usage_metadata if usage_metadata
end
|
969
|
+
|
970
|
+
# Initialize centroids using the k-means++ seeding scheme.
#
# The first centroid is picked uniformly at random. Each further centroid
# is drawn with probability proportional to the squared distance between a
# point and its nearest already-chosen centroid, so points that coincide
# with an existing centroid are never drawn while other candidates remain.
#
# @param vectors [Array<Array<Float>>] Input vectors
# @param k [Integer] Number of clusters
# @param metric [Symbol] Distance metric to use (:cosine or :euclidean)
# @return [Array<Array<Float>>] Initial centroids (copies of input vectors)
# @raise [Geminize::ValidationError] If the metric is not supported
def kmeans_plus_plus_init(vectors, k, metric)
  unless [:cosine, :euclidean].include?(metric)
    raise Geminize::ValidationError.new(
      "Unknown metric: #{metric}. Supported metrics: :cosine, :euclidean",
      "INVALID_ARGUMENT"
    )
  end

  # Choose first centroid randomly
  centroids = [vectors[rand(vectors.length)].dup]

  (k - 1).times do
    # Selection weight of each point: squared distance to its nearest centroid
    weights = vectors.map do |vec|
      nearest_similarity = -Float::INFINITY

      centroids.each do |centroid|
        similarity =
          if metric == :cosine
            VectorUtils.cosine_similarity(vec, centroid)
          else
            1.0 / (1.0 + VectorUtils.euclidean_distance(vec, centroid))
          end

        nearest_similarity = similarity if similarity > nearest_similarity
      end

      # Similarity -> distance; clamp float noise (similarity slightly above 1) to zero
      distance = 1.0 - nearest_similarity
      (distance > 0.0) ? distance * distance : 0.0
    end

    total_weight = weights.sum

    next_idx =
      if total_weight > 0.0 && total_weight.finite?
        # Roulette-wheel selection over the weights. Scaling the threshold by
        # the total is equivalent to normalizing the weights to sum to 1.
        target = rand * total_weight
        running = 0.0
        chosen = nil

        weights.each_with_index do |weight, idx|
          next unless weight > 0.0

          running += weight
          # Remember the last eligible index as a fallback against rounding error
          chosen = idx
          break if running >= target
        end

        chosen
      else
        # All points coincide with existing centroids: choose randomly
        rand(vectors.length)
      end

    centroids << vectors[next_idx].dup
  end

  centroids
end
|
1029
|
+
|
1030
|
+
# Pick the index of the vector that lies farthest away from the centroids.
#
# Vectors equal to one of the centroids are ignored. For every remaining
# vector the lowest cosine similarity to any centroid is turned into a
# distance (1.0 - similarity), and the vector with the largest such distance
# wins. Index 0 is returned when no vector qualifies.
#
# NOTE(review): `assignments` is accepted but never read, and cosine
# similarity is used regardless of the metric chosen for clustering.
#
# @param vectors [Array<Array<Float>>] Input vectors
# @param centroids [Array<Array<Float>>] Current centroids
# @param assignments [Array<Integer>] Current cluster assignments (unused)
# @return [Integer] Index of the farthest point
def find_farthest_point(vectors, centroids, assignments)
  winner = 0
  widest_gap = -Float::INFINITY

  vectors.each_with_index do |candidate, position|
    # A vector that already serves as a centroid is not a candidate
    next if centroids.include?(candidate)

    lowest_similarity = centroids.reduce(Float::INFINITY) do |floor, centroid|
      [floor, VectorUtils.cosine_similarity(candidate, centroid)].min
    end

    gap = 1.0 - lowest_similarity
    next unless gap > widest_gap

    widest_gap = gap
    winner = position
  end

  winner
end
|
1062
|
+
end
|
1063
|
+
end
|
1064
|
+
end
|