vectra-client 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.codecov.yml +31 -0
- data/.rspec +4 -0
- data/.rubocop.yml +183 -0
- data/.ruby-version +1 -0
- data/CHANGELOG.md +88 -0
- data/CODE_OF_CONDUCT.md +127 -0
- data/CONTRIBUTING.md +239 -0
- data/LICENSE +21 -0
- data/README.md +456 -0
- data/Rakefile +34 -0
- data/SECURITY.md +196 -0
- data/lib/vectra/client.rb +304 -0
- data/lib/vectra/configuration.rb +169 -0
- data/lib/vectra/errors.rb +73 -0
- data/lib/vectra/providers/base.rb +265 -0
- data/lib/vectra/providers/pgvector/connection.rb +75 -0
- data/lib/vectra/providers/pgvector/index_management.rb +122 -0
- data/lib/vectra/providers/pgvector/sql_helpers.rb +115 -0
- data/lib/vectra/providers/pgvector.rb +297 -0
- data/lib/vectra/providers/pinecone.rb +308 -0
- data/lib/vectra/providers/qdrant.rb +48 -0
- data/lib/vectra/providers/weaviate.rb +48 -0
- data/lib/vectra/query_result.rb +257 -0
- data/lib/vectra/vector.rb +155 -0
- data/lib/vectra/version.rb +5 -0
- data/lib/vectra.rb +133 -0
- metadata +226 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "faraday"
|
|
4
|
+
require "faraday/retry"
|
|
5
|
+
require "json"
|
|
6
|
+
|
|
7
|
+
module Vectra
|
|
8
|
+
module Providers
|
|
9
|
+
# Abstract base class for vector database providers
|
|
10
|
+
#
|
|
11
|
+
# All provider implementations must inherit from this class
|
|
12
|
+
# and implement the required methods.
|
|
13
|
+
#
|
|
14
|
+
class Base
  attr_reader :config

  # Initialize the provider
  #
  # @param config [Configuration] the configuration object
  def initialize(config)
    @config = config
    validate_config!
  end

  # Upsert vectors into an index
  #
  # @param index [String] the index/collection name
  # @param vectors [Array<Hash, Vector>] vectors to upsert
  # @param namespace [String, nil] optional namespace
  # @return [Hash] upsert response
  def upsert(index:, vectors:, namespace: nil)
    raise NotImplementedError, "#{self.class} must implement #upsert"
  end

  # Query vectors by similarity
  #
  # @param index [String] the index/collection name
  # @param vector [Array<Float>] query vector
  # @param top_k [Integer] number of results to return
  # @param namespace [String, nil] optional namespace
  # @param filter [Hash, nil] metadata filter
  # @param include_values [Boolean] include vector values in response
  # @param include_metadata [Boolean] include metadata in response
  # @return [QueryResult] query results
  def query(index:, vector:, top_k: 10, namespace: nil, filter: nil,
            include_values: false, include_metadata: true)
    raise NotImplementedError, "#{self.class} must implement #query"
  end

  # Fetch vectors by IDs
  #
  # @param index [String] the index/collection name
  # @param ids [Array<String>] vector IDs to fetch
  # @param namespace [String, nil] optional namespace
  # @return [Hash<String, Vector>] fetched vectors
  def fetch(index:, ids:, namespace: nil)
    raise NotImplementedError, "#{self.class} must implement #fetch"
  end

  # Update a vector's metadata
  #
  # @param index [String] the index/collection name
  # @param id [String] vector ID
  # @param metadata [Hash] new metadata
  # @param namespace [String, nil] optional namespace
  # @return [Hash] update response
  def update(index:, id:, metadata:, namespace: nil)
    raise NotImplementedError, "#{self.class} must implement #update"
  end

  # Delete vectors
  #
  # @param index [String] the index/collection name
  # @param ids [Array<String>, nil] vector IDs to delete
  # @param namespace [String, nil] optional namespace
  # @param filter [Hash, nil] delete by metadata filter
  # @param delete_all [Boolean] delete all vectors
  # @return [Hash] delete response
  def delete(index:, ids: nil, namespace: nil, filter: nil, delete_all: false)
    raise NotImplementedError, "#{self.class} must implement #delete"
  end

  # List indexes/collections
  #
  # @return [Array<Hash>] list of indexes
  def list_indexes
    raise NotImplementedError, "#{self.class} must implement #list_indexes"
  end

  # Describe an index
  #
  # @param index [String] the index name
  # @return [Hash] index details
  def describe_index(index:)
    raise NotImplementedError, "#{self.class} must implement #describe_index"
  end

  # Get index statistics
  #
  # @param index [String] the index name
  # @param namespace [String, nil] optional namespace
  # @return [Hash] index statistics
  def stats(index:, namespace: nil)
    raise NotImplementedError, "#{self.class} must implement #stats"
  end

  # Provider name
  #
  # @return [Symbol]
  def provider_name
    raise NotImplementedError, "#{self.class} must implement #provider_name"
  end

  protected

  # Build HTTP connection with retry logic
  #
  # @param base_url [String] base URL for the API
  # @param headers [Hash] request headers
  # @return [Faraday::Connection]
  def build_connection(base_url, headers = {})
    Faraday.new(url: base_url) do |conn|
      conn.request :json
      conn.response :json, content_type: /\bjson$/

      conn.request :retry, {
        max: config.max_retries,
        interval: config.retry_delay,
        interval_randomness: 0.5,
        backoff_factor: 2,
        retry_statuses: [429, 500, 502, 503, 504],
        exceptions: [
          Faraday::TimeoutError,
          Faraday::ConnectionFailed
        ]
      }

      conn.headers = default_headers.merge(headers)
      conn.options.timeout = config.timeout
      conn.options.open_timeout = config.open_timeout

      conn.adapter Faraday.default_adapter
    end
  end

  # Default headers for all requests
  #
  # @return [Hash]
  def default_headers
    {
      "Content-Type" => "application/json",
      "Accept" => "application/json",
      "User-Agent" => "vectra-ruby/#{Vectra::VERSION}"
    }
  end

  # Normalize vectors for API request
  #
  # @param vectors [Array<Hash, Vector>] vectors to normalize
  # @return [Array<Hash>]
  # @raise [ValidationError] if an element is neither Hash nor Vector
  def normalize_vectors(vectors)
    vectors.map do |vec|
      case vec
      when Vector
        vec.to_h
      when Hash
        normalize_vector_hash(vec)
      else
        raise ValidationError, "Vector must be a Hash or Vectra::Vector"
      end
    end
  end

  # Normalize a single vector hash
  #
  # @param hash [Hash] vector hash
  # @return [Hash]
  # @raise [ValidationError] if :id or :values is missing
  def normalize_vector_hash(hash)
    hash = hash.transform_keys(&:to_sym)

    # Fail loudly with the gem's own error class instead of letting a
    # missing :values blow up with NoMethodError (nil.map) or a missing
    # :id silently become the empty string.
    raise ValidationError, "Vector hash is missing required key :id" if hash[:id].nil?
    raise ValidationError, "Vector hash is missing required key :values" if hash[:values].nil?

    result = {
      id: hash[:id].to_s,
      values: hash[:values].map(&:to_f)
    }

    result[:metadata] = hash[:metadata] if hash[:metadata]
    result[:sparse_values] = hash[:sparse_values] if hash[:sparse_values]

    result
  end

  # Handle API errors
  #
  # @param response [Faraday::Response] the response
  # @raise [Error] appropriate error for the response
  def handle_error(response)
    status = response.status
    body = response.body

    error_message = extract_error_message(body)

    case status
    when 400
      raise ValidationError.new(error_message, response: response)
    when 401
      raise AuthenticationError.new(error_message, response: response)
    when 403
      raise AuthenticationError.new("Access forbidden: #{error_message}", response: response)
    when 404
      raise NotFoundError.new(error_message, response: response)
    when 429
      retry_after = response.headers["retry-after"]&.to_i
      raise RateLimitError.new(error_message, retry_after: retry_after, response: response)
    when 500..599
      raise ServerError.new(error_message, status_code: status, response: response)
    else
      raise Error.new("Request failed with status #{status}: #{error_message}", response: response)
    end
  end

  # Extract error message from response body
  #
  # @param body [Hash, String, nil] response body
  # @return [String]
  def extract_error_message(body)
    case body
    when Hash
      body["message"] || body["error"] || body.to_s
    when String
      body
    else
      "Unknown error"
    end
  end

  # Log debug information
  #
  # @param message [String] message to log
  # @param data [Hash] optional data to log
  def log_debug(message, data = nil)
    return unless config.logger

    config.logger.debug("[Vectra] #{message}")
    config.logger.debug("[Vectra] #{data.inspect}") if data
  end

  # Log error information
  #
  # @param message [String] message to log
  # @param error [Exception, nil] optional error
  def log_error(message, error = nil)
    return unless config.logger

    config.logger.error("[Vectra] #{message}")
    config.logger.error("[Vectra] #{error.class}: #{error.message}") if error
  end

  private

  # Delegates to Configuration#validate! so every provider fails fast
  # on an invalid configuration at construction time.
  def validate_config!
    config.validate!
  end
end
|
|
264
|
+
end
|
|
265
|
+
end
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Vectra
|
|
4
|
+
module Providers
|
|
5
|
+
class Pgvector < Base
|
|
6
|
+
# Connection management for pgvector provider
|
|
7
|
+
# Connection management for pgvector provider
module Connection
  private

  # Get or create the memoized database connection.
  #
  # The "pg" gem is required lazily (deliberate lazy-loading) so the
  # gem stays usable for providers that never touch PostgreSQL.
  #
  # @return [PG::Connection]
  def connection
    @connection ||= begin
      require "pg"

      PG.connect(parse_connection_params)
    end
  end

  # Build connection parameters from config.
  #
  # Accepts either a full connection URI in +config.host+
  # (postgres:// or postgresql://), or discrete settings where the
  # host may carry a "/database" suffix and +config.environment+
  # doubles as the port number.
  #
  # @return [Hash] keyword arguments for PG.connect
  def parse_connection_params
    return { conninfo: config.host } if config.host&.start_with?("postgres://", "postgresql://")

    {
      host: extract_host,
      port: extract_port,
      dbname: extract_database_name,
      user: extract_username,
      password: config.api_key
    }.compact
  end

  # Host portion of config.host without any "/database" suffix.
  # Previously the full "host/db" string was passed to PG.connect as
  # the host, which cannot resolve.
  def extract_host
    (config.host || "localhost").split("/").first
  end

  # Port from config.environment; falls back to 5432 when the value
  # is missing or not a positive integer (e.g. "production".to_i == 0,
  # which PG.connect would reject).
  def extract_port
    port = config.environment&.to_i
    port&.positive? ? port : 5432
  end

  # Extract database name from host or use default
  def extract_database_name
    if config.host&.include?("/")
      config.host.split("/").last
    else
      "postgres"
    end
  end

  # Username from the PGUSER environment variable (default "postgres").
  def extract_username
    ENV.fetch("PGUSER", "postgres")
  end

  # Execute SQL with bound parameters, translating PG errors into the
  # gem's error hierarchy.
  #
  # @param sql [String] SQL text with $1, $2, ... placeholders
  # @param params [Array] bound parameter values
  def execute(sql, params = [])
    log_debug("Executing SQL", { sql: sql, params: params })
    connection.exec_params(sql, params)
  rescue PG::Error => e
    handle_pg_error(e)
  end

  # Map PostgreSQL errors onto Vectra error classes. Order matters:
  # more specific PG classes are matched before the generic fallback.
  def handle_pg_error(error)
    case error
    when PG::UndefinedTable
      raise NotFoundError, "not found"
    when PG::InvalidPassword
      raise AuthenticationError, "authentication failed"
    when PG::ConnectionBad
      raise ConnectionError, "connection failed"
    when PG::UniqueViolation, PG::CheckViolation
      raise ValidationError, error.message
    else
      raise ServerError.new(error.message, status_code: 500)
    end
  end
end
|
|
73
|
+
end
|
|
74
|
+
end
|
|
75
|
+
end
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Vectra
|
|
4
|
+
module Providers
|
|
5
|
+
class Pgvector < Base
|
|
6
|
+
# Index management methods for pgvector provider
|
|
7
|
+
# Index management methods for pgvector provider
module IndexManagement
  # Maps Vectra metric names to pgvector operator classes for IVFFlat.
  INDEX_OPS = {
    "cosine" => "vector_cosine_ops",
    "euclidean" => "vector_l2_ops",
    "inner_product" => "vector_ip_ops"
  }.freeze

  # Create a new index (table with vector column)
  #
  # @param name [String] table/index name
  # @param dimension [Integer] vector dimension
  # @param metric [String] distance metric (see DISTANCE_FUNCTIONS)
  # @return [Hash] index description from #describe_index
  # @raise [ValidationError] on an unsupported metric
  def create_index(name:, dimension:, metric: DEFAULT_METRIC)
    validate_metric!(metric)
    create_table_with_vector(name, dimension)
    create_ivfflat_index(name, metric)
    store_metric_comment(name, metric)

    @table_cache[name] = { dimension: dimension, metric: metric }
    log_debug("Created index #{name}")
    describe_index(index: name)
  end

  # Delete an index (drop table)
  #
  # @param name [String] table/index name
  # @return [Hash] { deleted: true }
  def delete_index(name:)
    execute("DROP TABLE IF EXISTS #{quote_ident(name)} CASCADE")
    @table_cache.delete(name)
    log_debug("Deleted index #{name}")
    { deleted: true }
  end

  private

  # Validate metric type against the supported distance functions.
  def validate_metric!(metric)
    return if DISTANCE_FUNCTIONS.key?(metric)

    raise ValidationError, "Invalid metric '#{metric}'. Supported: #{DISTANCE_FUNCTIONS.keys.join(', ')}"
  end

  # Extract dimension from a "vector(N)" type string.
  #
  # @param type_info [String] e.g. "vector(384)"
  # @return [Integer, nil] nil when the string is not a vector type
  def extract_dimension_from_type(type_info)
    match = type_info.match(/vector\((\d+)\)/)
    return nil unless match

    match.captures.first.to_i
  end

  # Create the table with a vector column (enables the extension first).
  # The dimension is coerced with to_i before interpolation so it
  # cannot inject SQL.
  def create_table_with_vector(name, dimension)
    execute("CREATE EXTENSION IF NOT EXISTS vector")

    sql = <<~SQL
      CREATE TABLE IF NOT EXISTS #{quote_ident(name)} (
        id TEXT PRIMARY KEY,
        embedding vector(#{dimension.to_i}),
        metadata JSONB DEFAULT '{}',
        namespace TEXT DEFAULT '',
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
      )
    SQL
    execute(sql)
  end

  # Create IVFFlat index for similarity search.
  #
  # Failures here are non-fatal: the table is still usable with exact
  # (sequential) search, so the error is logged and creation deferred.
  def create_ivfflat_index(name, metric)
    index_op = INDEX_OPS[metric]
    idx_name = "#{name}_embedding_idx"
    idx_sql = <<~SQL
      CREATE INDEX IF NOT EXISTS #{quote_ident(idx_name)}
      ON #{quote_ident(name)}
      USING ivfflat (embedding #{index_op})
      WITH (lists = 100)
    SQL

    execute(idx_sql)
  rescue Error => e
    # NOTE: #execute already converts PG::Error into Vectra's error
    # classes, so the previous `rescue PG::Error` was dead code and
    # index-creation failures escaped instead of being deferred.
    log_debug("Index creation deferred: #{e.message}")
  end

  # Store metric in table comment so it can be recovered later by
  # #table_metric.
  def store_metric_comment(name, metric)
    execute("COMMENT ON TABLE #{quote_ident(name)} IS #{escape_literal("vectra:metric=#{metric}")}")
  end

  # Ensure the backing table exists, caching positive lookups.
  #
  # @raise [NotFoundError] when the table does not exist
  def ensure_table_exists!(index)
    return if @table_cache.key?(index)

    sql = <<~SQL
      SELECT EXISTS (
        SELECT FROM information_schema.tables
        WHERE table_schema = 'public' AND table_name = $1
      )
    SQL

    result = execute(sql, [index])
    exists_value = result.first["exists"]
    # pg may return true or the string "t" depending on type mapping.
    exists = [true, "t"].include?(exists_value)

    raise NotFoundError, "Index '#{index}' not found" unless exists

    @table_cache[index] = true
  end

  # Get metric for a table from the cache or the stored table comment;
  # falls back to DEFAULT_METRIC when no comment is present.
  def table_metric(index)
    return @table_cache.dig(index, :metric) if @table_cache[index].is_a?(Hash)

    sql = "SELECT obj_description($1::regclass, 'pg_class') as comment"
    result = execute(sql, [index])
    comment = result.first&.fetch("comment", nil)

    metric = comment&.include?("vectra:metric=") ? comment.match(/vectra:metric=(\w+)/)&.captures&.first : nil
    metric || DEFAULT_METRIC
  end
end
|
|
120
|
+
end
|
|
121
|
+
end
|
|
122
|
+
end
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Vectra
|
|
4
|
+
module Providers
|
|
5
|
+
class Pgvector < Base
|
|
6
|
+
# SQL helper methods for pgvector provider
|
|
7
|
+
# SQL helper methods for pgvector provider
module SqlHelpers
  private

  # Quote an identifier (table/column name) to prevent SQL injection.
  def quote_ident(name)
    connection.quote_ident(name)
  end

  # Escape a literal string for inline SQL.
  def escape_literal(str)
    connection.escape_literal(str)
  end

  # Format a Ruby array as a pgvector literal, e.g. "[1.0,2.5]".
  # Every element is coerced with to_f, so the result is numeric-only.
  def format_vector(values)
    "[#{values.map(&:to_f).join(',')}]"
  end

  # Parse a pgvector string literal back into an Array<Float>.
  #
  # @return [Array<Float>, nil] nil for a nil input
  def parse_vector(str)
    return nil unless str

    str.gsub(/[\[\]]/, "").split(",").map(&:to_f)
  end

  # Parse a JSONB column value. Tolerates already-decoded hashes and
  # returns {} for nil, non-JSON types, or malformed input.
  def parse_json(str)
    return {} unless str

    case str
    when String
      JSON.parse(str)
    when Hash
      str
    else
      {}
    end
  rescue JSON::ParserError
    {}
  end

  # Build SQL for a vector similarity query.
  #
  # vector_literal must come from #format_vector (numeric-only) and
  # top_k is coerced with to_i, so their interpolation is safe;
  # namespace/filter values go through escape_literal.
  def build_query_sql(index:, vector_literal:, distance_op:, top_k:,
                      namespace:, filter:, include_values:, include_metadata:)
    select_cols = build_select_columns(vector_literal, distance_op, include_values, include_metadata)
    where_clauses = build_where_clauses(namespace, filter)

    sql = "SELECT #{select_cols.join(', ')} FROM #{quote_ident(index)}"
    sql += " WHERE #{where_clauses.join(' AND ')}" if where_clauses.any?
    sql += " ORDER BY embedding #{distance_op} '#{vector_literal}'::vector"
    sql += " LIMIT #{top_k.to_i}"
    sql
  end

  # Build SELECT columns for query; score is 1 - distance so that
  # larger is better (exact similarity semantics depend on distance_op).
  def build_select_columns(vector_literal, distance_op, include_values, include_metadata)
    cols = ["id", "1 - (embedding #{distance_op} '#{vector_literal}'::vector) as score"]
    cols << "embedding" if include_values
    cols << "metadata" if include_metadata
    cols
  end

  # Build WHERE clauses for the namespace and metadata filter.
  def build_where_clauses(namespace, filter)
    clauses = []
    clauses << "namespace = #{escape_literal(namespace)}" if namespace

    filter&.each do |key, value|
      json_path = "metadata->>#{escape_literal(key.to_s)}"
      clauses << "#{json_path} = #{escape_literal(value.to_s)}"
    end

    clauses
  end

  # Build a match hash from a database row.
  def build_match_from_row(row, include_values, include_metadata)
    match = { id: row["id"], score: row["score"].to_f }
    match[:values] = parse_vector(row["embedding"]) if include_values && row["embedding"]
    match[:metadata] = parse_json(row["metadata"]) if include_metadata && row["metadata"]
    match
  end

  # Build SQL for filter-based delete using bound parameters.
  #
  # @return [Array(String, Array)] the SQL and its parameter list
  # @raise [ValidationError] when neither filter nor namespace yields a
  #   predicate — previously this produced "DELETE ... WHERE " with no
  #   condition, i.e. invalid SQL; an unconditional delete must go
  #   through delete_all instead.
  def build_filter_delete_sql(index, filter, namespace)
    clauses = []
    params = []
    param_idx = 1

    filter.each do |key, value|
      clauses << "metadata->>$#{param_idx} = $#{param_idx + 1}"
      params << key.to_s
      params << value.to_s
      param_idx += 2
    end

    if namespace
      clauses << "namespace = $#{param_idx}"
      params << namespace
    end

    raise ValidationError, "Filter must not be empty" if clauses.empty?

    sql = "DELETE FROM #{quote_ident(index)} WHERE #{clauses.join(' AND ')}"
    [sql, params]
  end
end
|
|
113
|
+
end
|
|
114
|
+
end
|
|
115
|
+
end
|