geminize 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/.standard.yml +3 -0
- data/.yardopts +14 -0
- data/CHANGELOG.md +24 -0
- data/CODE_OF_CONDUCT.md +132 -0
- data/CONTRIBUTING.md +109 -0
- data/LICENSE.txt +21 -0
- data/README.md +423 -0
- data/Rakefile +10 -0
- data/examples/README.md +75 -0
- data/examples/configuration.rb +58 -0
- data/examples/embeddings.rb +195 -0
- data/examples/multimodal.rb +126 -0
- data/examples/rails_chat/README.md +69 -0
- data/examples/rails_chat/app/controllers/chat_controller.rb +26 -0
- data/examples/rails_chat/app/views/chat/index.html.erb +112 -0
- data/examples/rails_chat/config/routes.rb +8 -0
- data/examples/rails_initializer.rb +46 -0
- data/examples/system_instructions.rb +101 -0
- data/lib/geminize/chat.rb +98 -0
- data/lib/geminize/client.rb +318 -0
- data/lib/geminize/configuration.rb +98 -0
- data/lib/geminize/conversation_repository.rb +161 -0
- data/lib/geminize/conversation_service.rb +126 -0
- data/lib/geminize/embeddings.rb +145 -0
- data/lib/geminize/error_mapper.rb +96 -0
- data/lib/geminize/error_parser.rb +120 -0
- data/lib/geminize/errors.rb +185 -0
- data/lib/geminize/middleware/error_handler.rb +72 -0
- data/lib/geminize/model_info.rb +91 -0
- data/lib/geminize/models/chat_request.rb +186 -0
- data/lib/geminize/models/chat_response.rb +118 -0
- data/lib/geminize/models/content_request.rb +530 -0
- data/lib/geminize/models/content_response.rb +99 -0
- data/lib/geminize/models/conversation.rb +156 -0
- data/lib/geminize/models/embedding_request.rb +222 -0
- data/lib/geminize/models/embedding_response.rb +1064 -0
- data/lib/geminize/models/memory.rb +88 -0
- data/lib/geminize/models/message.rb +140 -0
- data/lib/geminize/models/model.rb +171 -0
- data/lib/geminize/models/model_list.rb +124 -0
- data/lib/geminize/models/stream_response.rb +99 -0
- data/lib/geminize/rails/app/controllers/concerns/geminize/controller.rb +105 -0
- data/lib/geminize/rails/app/helpers/geminize_helper.rb +125 -0
- data/lib/geminize/rails/controller_additions.rb +41 -0
- data/lib/geminize/rails/engine.rb +29 -0
- data/lib/geminize/rails/helper_additions.rb +37 -0
- data/lib/geminize/rails.rb +50 -0
- data/lib/geminize/railtie.rb +33 -0
- data/lib/geminize/request_builder.rb +57 -0
- data/lib/geminize/text_generation.rb +285 -0
- data/lib/geminize/validators.rb +150 -0
- data/lib/geminize/vector_utils.rb +164 -0
- data/lib/geminize/version.rb +5 -0
- data/lib/geminize.rb +527 -0
- data/lib/generators/geminize/install_generator.rb +22 -0
- data/lib/generators/geminize/templates/README +31 -0
- data/lib/generators/geminize/templates/initializer.rb +38 -0
- data/sig/geminize.rbs +4 -0
- metadata +218 -0
@@ -0,0 +1,285 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Geminize
  # Text generation entry point: builds content requests and sends them through
  # a client, either as a single request or as a stream of chunks.
  class TextGeneration
    # Stream processing modes accepted by #generate_text_stream.
    STREAM_MODES = [:raw, :incremental, :delta].freeze
    private_constant :STREAM_MODES

    # @return [Geminize::Client] The client instance
    attr_reader :client

    # Initialize a new text generation instance
    # @param client [Geminize::Client, nil] The client to use; when nil a new
    #   Client is built from +options+
    # @param options [Hash] Options passed to Client.new when +client+ is nil
    #   (also kept on the instance)
    def initialize(client = nil, options = {})
      @client = client || Client.new(options)
      @options = options
    end

    # Generate text based on a content request
    # @param content_request [Geminize::Models::ContentRequest] The content request
    # @return [Geminize::Models::ContentResponse] The generation response
    # @raise [Geminize::GeminizeError] If the request fails
    def generate(content_request)
      model_name = content_request.model_name
      endpoint = RequestBuilder.build_text_generation_endpoint(model_name)
      payload = RequestBuilder.build_text_generation_request(content_request)

      response_data = @client.post(endpoint, payload)
      Models::ContentResponse.from_hash(response_data)
    end

    # Generate text from a prompt string with optional parameters
    # @param prompt [String] The input prompt
    # @param model_name [String, nil] The model to use; defaults to
    #   Geminize.configuration.default_model
    # @param params [Hash] Additional generation parameters
    # @option params [Float] :temperature Controls randomness (0.0-1.0)
    # @option params [Integer] :max_tokens Maximum tokens to generate
    # @option params [Float] :top_p Top-p value for nucleus sampling (0.0-1.0)
    # @option params [Integer] :top_k Top-k value for sampling
    # @option params [Array<String>] :stop_sequences Stop sequences to end generation
    # @option params [String] :system_instruction System instruction to guide model behavior
    # @return [Geminize::Models::ContentResponse] The generation response
    # @raise [Geminize::GeminizeError] If the request fails
    def generate_text(prompt, model_name = nil, params = {})
      model = model_name || Geminize.configuration.default_model

      content_request = Models::ContentRequest.new(
        prompt,
        model,
        params
      )

      generate(content_request)
    end

    # Generate a streaming text response from a prompt string with optional parameters
    # @param prompt [String] The input prompt
    # @param model_name [String, nil] The model to use; defaults to
    #   Geminize.configuration.default_model
    # @param params [Hash] Additional generation parameters (not modified)
    # @option params [Float] :temperature Controls randomness (0.0-1.0)
    # @option params [Integer] :max_tokens Maximum tokens to generate
    # @option params [Float] :top_p Top-p value for nucleus sampling (0.0-1.0)
    # @option params [Integer] :top_k Top-k value for sampling
    # @option params [Array<String>] :stop_sequences Stop sequences to end generation
    # @option params [Symbol] :stream_mode Mode for processing stream chunks
    #   (:raw, :incremental, or :delta; default :incremental). It is stripped
    #   before the remaining params are sent to the model.
    # @option params [String] :system_instruction System instruction to guide model behavior
    # @yield [chunk] Yields each chunk of the streaming response
    # @yieldparam chunk [String, Hash, StreamResponse] A chunk of the response
    # @return [void]
    # @raise [ArgumentError] If no block is given or :stream_mode is invalid
    # @raise [Geminize::GeminizeError] If the request fails
    # @example Generate text with streaming, yielding each chunk
    #   text_generation.generate_text_stream("Tell me a story") do |chunk|
    #     puts chunk
    #   end
    # @example Generate text with incremental mode
    #   accumulated_text = ""
    #   text_generation.generate_text_stream("Tell me a story", nil, stream_mode: :incremental) do |text|
    #     # text contains the full response so far
    #     print "\r#{text}"
    #   end
    # @example Generate text with delta mode (only new content)
    #   text_generation.generate_text_stream("Tell me a story", nil, stream_mode: :delta) do |new_text|
    #     # new_text contains only the new content in this chunk
    #     print new_text
    #   end
    def generate_text_stream(prompt, model_name = nil, params = {}, &block)
      raise ArgumentError, "A block is required for streaming" unless block_given?

      # Work on a copy: deleting :stream_mode from the caller's hash would make a
      # reused params hash silently fall back to :incremental on the next call.
      params = params.dup
      stream_mode = params.delete(:stream_mode) || :incremental
      unless STREAM_MODES.include?(stream_mode)
        raise ArgumentError, "Invalid stream_mode. Must be :raw, :incremental, or :delta"
      end

      content_request = Models::ContentRequest.new(
        prompt,
        model_name || Geminize.configuration.default_model,
        params
      )

      generate_stream(content_request, stream_mode, &block)
    end

    # Generate content with both text and images
    # @param prompt [String] The input prompt text
    # @param images [Array<Hash>] Array of image data hashes
    # @param model_name [String, nil] The model to use; defaults to
    #   Geminize.configuration.default_model
    # @param params [Hash] Additional generation parameters
    # @option params [Float] :temperature Controls randomness (0.0-1.0)
    # @option params [Integer] :max_tokens Maximum tokens to generate
    # @option params [Float] :top_p Top-p value for nucleus sampling (0.0-1.0)
    # @option params [Integer] :top_k Top-k value for sampling
    # @option params [Array<String>] :stop_sequences Stop sequences to end generation
    # @option images [String, Symbol] :source_type Source type for image
    #   ('file', 'bytes', or 'url'; symbols are accepted too)
    # @option images [String] :data File path, raw bytes, or URL depending on source_type
    # @option images [String] :mime_type MIME type for the image (passed only for 'bytes')
    # @return [Geminize::Models::ContentResponse] The generation response
    # @raise [Geminize::ValidationError] If an image has an unknown source type
    # @raise [Geminize::GeminizeError] If the request fails
    # @example Generate with an image file
    #   generate_text_multimodal("Describe this image", [{source_type: 'file', data: 'path/to/image.jpg'}])
    # @example Generate with multiple images
    #   generate_text_multimodal("Compare these images", [
    #     {source_type: 'file', data: 'path/to/image1.jpg'},
    #     {source_type: 'url', data: 'https://example.com/image2.jpg'}
    #   ])
    def generate_text_multimodal(prompt, images, model_name = nil, params = {})
      content_request = Models::ContentRequest.new(
        prompt,
        model_name || Geminize.configuration.default_model,
        params
      )

      images.each do |image|
        # Compare as a string so :file / :bytes / :url work as well as "file" etc.
        case image[:source_type].to_s
        when "file"
          content_request.add_image_from_file(image[:data])
        when "bytes"
          content_request.add_image_from_bytes(image[:data], image[:mime_type])
        when "url"
          content_request.add_image_from_url(image[:data])
        else
          raise Geminize::ValidationError.new(
            "Invalid image source type: #{image[:source_type]}. Must be 'file', 'bytes', or 'url'",
            "INVALID_ARGUMENT"
          )
        end
      end

      generate(content_request)
    end

    # Generate text with retries for transient errors
    # (Geminize::RateLimitError and Geminize::ServerError).
    # @param content_request [Geminize::Models::ContentRequest] The content request
    # @param max_retries [Integer] Maximum number of retry attempts
    # @param retry_delay [Float] Base delay in seconds; attempt n waits n * retry_delay
    # @return [Geminize::Models::ContentResponse] The generation response
    # @raise [Geminize::GeminizeError] If all retry attempts fail
    def generate_with_retries(content_request, max_retries = 3, retry_delay = 1.0)
      retries = 0

      begin
        generate(content_request)
      rescue Geminize::RateLimitError, Geminize::ServerError
        raise if retries >= max_retries

        retries += 1
        # Linear backoff: retry_delay, 2 * retry_delay, 3 * retry_delay, ...
        sleep retry_delay * retries
        retry
      end
    end

    # Request cancellation of the current streaming operation, if any, by
    # setting the client's +cancel_streaming+ flag. Whether and when a stream
    # actually stops depends on how the client checks that flag.
    # @return [Boolean] false when there is no client; otherwise true once the
    #   flag has been set (this does not indicate that a stream was in progress)
    def cancel_streaming
      return false unless @client

      @client.cancel_streaming = true
      true
    end

    private

    # Generate text with streaming from a content request
    # @param content_request [Geminize::Models::ContentRequest] The content request
    # @param stream_mode [Symbol] The stream processing mode (:raw, :incremental, or :delta)
    # @yield [chunk] Yields each chunk of the streaming response
    # @yieldparam chunk [String, Hash, StreamResponse] Raw chunk (:raw), full text
    #   so far (:incremental), new text only (:delta), or a final Hash with
    #   :text, :finish_reason and :usage when the last chunk carries usage metrics
    # @return [void]
    # @raise [Geminize::GeminizeError] If the request fails. Streaming errors are
    #   re-raised with added context; other GeminizeError subclasses propagate
    #   unchanged; any other StandardError (including one raised by the caller's
    #   block) is wrapped in a GeminizeError.
    def generate_stream(content_request, stream_mode = :incremental, &block)
      model_name = content_request.model_name
      endpoint = RequestBuilder.build_streaming_endpoint(model_name)
      payload = RequestBuilder.build_text_generation_request(content_request)

      # Full text received so far; only tracked in the text-extracting modes.
      accumulated_text = "" if [:incremental, :delta].include?(stream_mode)

      # Becomes true as soon as the client hands us the first chunk
      received_successful_chunks = false

      begin
        @client.post_stream(endpoint, payload) do |chunk|
          received_successful_chunks = true

          case stream_mode
          when :raw
            # Raw mode - yield the chunk as-is
            yield chunk
          when :incremental, :delta
            stream_response = Models::StreamResponse.from_hash(chunk)
            new_text = stream_response.text

            # Only process and yield if there's text content in this chunk
            if new_text
              accumulated_text += new_text
              if stream_mode == :incremental
                yield accumulated_text
              elsif !new_text.empty?
                # Delta mode: only the text contributed by this chunk
                yield new_text
              end
            end

            if stream_response.final_chunk? && stream_response.has_usage_metrics?
              yield final_chunk_payload(stream_response, accumulated_text)
            end
          end
        end
      rescue StreamingError, StreamingInterruptedError, StreamingTimeoutError, InvalidStreamFormatError => e
        # For specialized streaming errors, add context and re-raise
        error_message = "Streaming error: #{e.message}"

        # If text had already been received, report how much in the error message
        if received_successful_chunks && accumulated_text && !accumulated_text.empty?
          error_message += " (Partial response received: #{accumulated_text.length} characters)"
          raise e.class.new(error_message, e.code, e.http_status)
        else
          # No partial text to report: re-raise the original error unchanged
          raise
        end
      rescue GeminizeError
        # Already a library error (e.g. a rate-limit or validation error): keep
        # its specific class so callers can still rescue it precisely.
        raise
      rescue => e
        # For other errors, wrap in a GeminizeError
        error_message = "Error during text generation streaming: #{e.message}"
        raise GeminizeError.new(error_message)
      end
    end

    # Build the summary hash yielded for a final chunk that carries usage metrics.
    # @param stream_response [Geminize::Models::StreamResponse] The final chunk
    # @param accumulated_text [String] The complete text received so far
    # @return [Hash] {text:, finish_reason:, usage: {prompt_tokens:, completion_tokens:, total_tokens:}}
    def final_chunk_payload(stream_response, accumulated_text)
      {
        text: accumulated_text,
        finish_reason: stream_response.finish_reason,
        usage: {
          prompt_tokens: stream_response.prompt_tokens,
          completion_tokens: stream_response.completion_tokens,
          total_tokens: stream_response.total_tokens
        }
      }
    end
  end
end
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Geminize
  # Utility module for validating parameters.
  #
  # Each validator returns nothing useful on success and raises
  # Geminize::ValidationError (code "INVALID_ARGUMENT") on failure. Except for
  # validate_string! and validate_not_empty!, a nil value is treated as
  # "not provided" and accepted.
  module Validators
    class << self
      # Validate that a value is a String (nil is rejected).
      # Emptiness is not checked here; use validate_not_empty! for that.
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @raise [Geminize::ValidationError] If the value is nil or not a String
      # @return [void]
      def validate_string!(value, param_name)
        if value.nil?
          raise Geminize::ValidationError.new("#{param_name} cannot be nil", "INVALID_ARGUMENT")
        end

        unless value.is_a?(String)
          raise Geminize::ValidationError.new("#{param_name} must be a string", "INVALID_ARGUMENT")
        end
      end

      # Validate that a value is a non-nil, non-empty String
      # @param value [String] The string to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @raise [Geminize::ValidationError] If the value is nil, not a String, or empty
      # @return [void]
      def validate_not_empty!(value, param_name)
        validate_string!(value, param_name)

        if value.empty?
          raise Geminize::ValidationError.new("#{param_name} cannot be empty", "INVALID_ARGUMENT")
        end
      end

      # Validate that a value is a number in the specified range (nil is accepted)
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @param min [Numeric, nil] The minimum allowed value (inclusive)
      # @param max [Numeric, nil] The maximum allowed value (inclusive)
      # @raise [Geminize::ValidationError] If the value is not Numeric, is NaN,
      #   or falls outside [min, max]
      # @return [void]
      def validate_numeric!(value, param_name, min: nil, max: nil)
        return if value.nil?

        # NaN is Numeric but compares false against every bound, so without the
        # explicit check it would pass any min/max range unnoticed.
        if !value.is_a?(Numeric) || (value.respond_to?(:nan?) && value.nan?)
          raise Geminize::ValidationError.new("#{param_name} must be a number", "INVALID_ARGUMENT")
        end

        if min && value < min
          raise Geminize::ValidationError.new("#{param_name} must be at least #{min}", "INVALID_ARGUMENT")
        end

        if max && value > max
          raise Geminize::ValidationError.new("#{param_name} must be at most #{max}", "INVALID_ARGUMENT")
        end
      end

      # Validate that a value is an Integer in the specified range (nil is accepted)
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @param min [Integer, nil] The minimum allowed value (inclusive)
      # @param max [Integer, nil] The maximum allowed value (inclusive)
      # @raise [Geminize::ValidationError] If validation fails
      # @return [void]
      def validate_integer!(value, param_name, min: nil, max: nil)
        return if value.nil?

        unless value.is_a?(Integer)
          raise Geminize::ValidationError.new("#{param_name} must be an integer", "INVALID_ARGUMENT")
        end

        validate_numeric!(value, param_name, min: min, max: max)
      end

      # Validate that a value is an Integer greater than zero (nil is accepted)
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @raise [Geminize::ValidationError] If validation fails
      # @return [void]
      def validate_positive_integer!(value, param_name)
        return if value.nil?

        validate_integer!(value, param_name)

        if value <= 0
          raise Geminize::ValidationError.new("#{param_name} must be positive", "INVALID_ARGUMENT")
        end
      end

      # Validate that a value is a number between 0.0 and 1.0 inclusive (nil is accepted)
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @raise [Geminize::ValidationError] If validation fails
      # @return [void]
      def validate_probability!(value, param_name)
        return if value.nil?

        validate_numeric!(value, param_name, min: 0.0, max: 1.0)
      end

      # Validate that a value is an Array (nil is accepted)
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @raise [Geminize::ValidationError] If validation fails
      # @return [void]
      def validate_array!(value, param_name)
        return if value.nil?

        unless value.is_a?(Array)
          raise Geminize::ValidationError.new("#{param_name} must be an array", "INVALID_ARGUMENT")
        end
      end

      # Validate that a value is an Array whose elements are all Strings (nil is accepted)
      # @param value [Array] The array to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @raise [Geminize::ValidationError] If validation fails; the message names
      #   the first offending index
      # @return [void]
      def validate_string_array!(value, param_name)
        return if value.nil?

        validate_array!(value, param_name)

        value.each_with_index do |item, index|
          unless item.is_a?(String)
            raise Geminize::ValidationError.new("#{param_name}[#{index}] must be a string", "INVALID_ARGUMENT")
          end
        end
      end

      # Validate that a value is one of an allowed set of values (nil is accepted)
      # @param value [Object] The value to validate
      # @param param_name [String] The name of the parameter (for error messages)
      # @param allowed_values [Array] The allowed values (compared with include?)
      # @raise [Geminize::ValidationError] If validation fails
      # @return [void]
      def validate_allowed_values!(value, param_name, allowed_values)
        return if value.nil?

        unless allowed_values.include?(value)
          allowed_str = allowed_values.map(&:inspect).join(", ")
          raise Geminize::ValidationError.new(
            "#{param_name} must be one of: #{allowed_str}",
            "INVALID_ARGUMENT"
          )
        end
      end
    end
  end
end
|
@@ -0,0 +1,164 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Geminize
  # Utility module for vector operations used with embeddings.
  #
  # Vectors are arrays of numbers. Operations on two vectors require equal
  # lengths and raise Geminize::ValidationError otherwise.
  module VectorUtils
    class << self
      # Calculate the cosine similarity between two vectors
      # @param vec1 [Array<Numeric>] First vector
      # @param vec2 [Array<Numeric>] Second vector
      # @return [Float] Cosine similarity in [-1.0, 1.0]; 0.0 when either vector
      #   has zero magnitude
      # @raise [Geminize::ValidationError] If vectors have different dimensions
      def cosine_similarity(vec1, vec2)
        ensure_same_dimensions!(vec1, vec2)

        magnitude1 = Math.sqrt(vec1.sum(0.0) { |v| v * v })
        magnitude2 = Math.sqrt(vec2.sum(0.0) { |v| v * v })

        # Guard against division by zero
        return 0.0 if magnitude1.zero? || magnitude2.zero?

        similarity = dot_product(vec1, vec2) / (magnitude1 * magnitude2)

        # Rounding can push the ratio marginally past +/-1 (e.g. a vector compared
        # with itself). Clamp without Comparable#clamp so a NaN result passes through
        # instead of raising.
        if similarity > 1.0
          1.0
        elsif similarity < -1.0
          -1.0
        else
          similarity
        end
      end

      # Calculate the Euclidean distance between two vectors
      # @param vec1 [Array<Numeric>] First vector
      # @param vec2 [Array<Numeric>] Second vector
      # @return [Float] Euclidean distance
      # @raise [Geminize::ValidationError] If vectors have different dimensions
      def euclidean_distance(vec1, vec2)
        ensure_same_dimensions!(vec1, vec2)

        sum_square_diff = vec1.zip(vec2).sum(0.0) do |v1, v2|
          diff = v1 - v2
          diff * diff
        end

        Math.sqrt(sum_square_diff)
      end

      # Calculate the dot product of two vectors
      # @param vec1 [Array<Numeric>] First vector
      # @param vec2 [Array<Numeric>] Second vector
      # @return [Float] Dot product (always a Float, even for Integer inputs)
      # @raise [Geminize::ValidationError] If vectors have different dimensions
      def dot_product(vec1, vec2)
        ensure_same_dimensions!(vec1, vec2)

        # 0.0 seed keeps the result a Float for Integer inputs
        vec1.zip(vec2).sum(0.0) { |v1, v2| v1 * v2 }
      end

      # Normalize a vector to unit length
      # @param vec [Array<Numeric>] Vector to normalize
      # @return [Array<Float>] Normalized vector; all zeros for a zero vector
      def normalize(vec)
        magnitude = Math.sqrt(vec.sum(0.0) { |v| v * v })

        # Handle zero magnitude vector
        return vec.map { 0.0 } if magnitude.zero?

        vec.map { |v| v / magnitude }
      end

      # Average multiple vectors element-wise
      # @param vectors [Array<Array<Numeric>>] Array of vectors
      # @return [Array<Float>] Average vector
      # @raise [Geminize::ValidationError] If vectors have different dimensions or no vectors provided
      def average_vectors(vectors)
        if vectors.empty?
          raise Geminize::ValidationError.new(
            "Cannot average an empty array of vectors",
            "INVALID_ARGUMENT"
          )
        end

        # Check all vectors have same dimensionality
        dim = vectors.first.length
        vectors.each_with_index do |vec, i|
          next if vec.length == dim

          raise Geminize::ValidationError.new(
            "All vectors must have the same dimensions (expected #{dim}, got #{vec.length} at index #{i})",
            "INVALID_ARGUMENT"
          )
        end

        sums = Array.new(dim, 0.0)
        vectors.each do |vec|
          vec.each_with_index { |v, i| sums[i] += v }
        end

        sums.map { |sum| sum / vectors.length }
      end

      # Find the most similar vectors to a target vector
      # @param target [Array<Numeric>] Target vector
      # @param vectors [Array<Array<Numeric>>] Vectors to compare against
      # @param top_k [Integer, nil] Number of most similar vectors to return (nil = all)
      # @param metric [Symbol] Similarity metric to use (:cosine or :euclidean).
      #   :euclidean scores are 1 / (1 + distance), so higher is more similar.
      # @return [Array<Hash>] Array of {index:, similarity:} hashes sorted by
      #   descending similarity; ties keep their original index order
      # @raise [Geminize::ValidationError] If the metric is unknown (even when
      #   +vectors+ is empty) or a vector's dimensions differ from +target+
      def most_similar(target, vectors, top_k = nil, metric = :cosine)
        # Resolve the metric up front so an invalid one is always reported
        scorer =
          case metric
          when :cosine
            ->(vec) { cosine_similarity(target, vec) }
          when :euclidean
            # Convert to similarity (higher is more similar)
            ->(vec) { 1.0 / (1.0 + euclidean_distance(target, vec)) }
          else
            raise Geminize::ValidationError.new(
              "Unknown metric: #{metric}. Supported metrics: :cosine, :euclidean",
              "INVALID_ARGUMENT"
            )
          end

        similarities = vectors.each_with_index.map do |vec, i|
          {index: i, similarity: scorer.call(vec)}
        end

        # sort_by is not stable; breaking ties on the index keeps output deterministic
        sorted = similarities.sort_by { |s| [-s[:similarity], s[:index]] }
        top_k ? sorted.take(top_k) : sorted
      end

      private

      # Raise a ValidationError unless both vectors have the same length.
      def ensure_same_dimensions!(vec1, vec2)
        return if vec1.length == vec2.length

        raise Geminize::ValidationError.new(
          "Vectors must have the same dimensions (#{vec1.length} vs #{vec2.length})",
          "INVALID_ARGUMENT"
        )
      end
    end
  end
end
|