chromadb-experimental 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. checksums.yaml +7 -0
  2. data/lib/chromadb/admin_client.rb +6 -0
  3. data/lib/chromadb/client.rb +317 -0
  4. data/lib/chromadb/collection.rb +573 -0
  5. data/lib/chromadb/embedding_functions/chroma_bm25.rb +459 -0
  6. data/lib/chromadb/embedding_functions/chroma_cloud_qwen.rb +139 -0
  7. data/lib/chromadb/embedding_functions/chroma_cloud_splade.rb +121 -0
  8. data/lib/chromadb/embedding_functions.rb +121 -0
  9. data/lib/chromadb/errors.rb +120 -0
  10. data/lib/chromadb/http_client.rb +142 -0
  11. data/lib/chromadb/openapi/lib/chromadb/api/default_api.rb +2349 -0
  12. data/lib/chromadb/openapi/lib/chromadb/api_client.rb +392 -0
  13. data/lib/chromadb/openapi/lib/chromadb/api_error.rb +58 -0
  14. data/lib/chromadb/openapi/lib/chromadb/configuration.rb +295 -0
  15. data/lib/chromadb/openapi/lib/chromadb/models/add_collection_records_payload.rb +260 -0
  16. data/lib/chromadb/openapi/lib/chromadb/models/attach_function_request.rb +250 -0
  17. data/lib/chromadb/openapi/lib/chromadb/models/attach_function_response.rb +235 -0
  18. data/lib/chromadb/openapi/lib/chromadb/models/attached_function_api_response.rb +361 -0
  19. data/lib/chromadb/openapi/lib/chromadb/models/attached_function_info.rb +240 -0
  20. data/lib/chromadb/openapi/lib/chromadb/models/bool_inverted_index_type.rb +229 -0
  21. data/lib/chromadb/openapi/lib/chromadb/models/bool_value_type.rb +221 -0
  22. data/lib/chromadb/openapi/lib/chromadb/models/checklist_response.rb +245 -0
  23. data/lib/chromadb/openapi/lib/chromadb/models/collection.rb +315 -0
  24. data/lib/chromadb/openapi/lib/chromadb/models/collection_configuration.rb +240 -0
  25. data/lib/chromadb/openapi/lib/chromadb/models/create_collection_payload.rb +260 -0
  26. data/lib/chromadb/openapi/lib/chromadb/models/create_database_payload.rb +220 -0
  27. data/lib/chromadb/openapi/lib/chromadb/models/create_tenant_payload.rb +220 -0
  28. data/lib/chromadb/openapi/lib/chromadb/models/database.rb +240 -0
  29. data/lib/chromadb/openapi/lib/chromadb/models/detach_function_request.rb +221 -0
  30. data/lib/chromadb/openapi/lib/chromadb/models/detach_function_response.rb +220 -0
  31. data/lib/chromadb/openapi/lib/chromadb/models/embedding_function_new_configuration.rb +230 -0
  32. data/lib/chromadb/openapi/lib/chromadb/models/error_response.rb +230 -0
  33. data/lib/chromadb/openapi/lib/chromadb/models/float_inverted_index_type.rb +229 -0
  34. data/lib/chromadb/openapi/lib/chromadb/models/float_list_value_type.rb +221 -0
  35. data/lib/chromadb/openapi/lib/chromadb/models/float_value_type.rb +221 -0
  36. data/lib/chromadb/openapi/lib/chromadb/models/fork_collection_payload.rb +220 -0
  37. data/lib/chromadb/openapi/lib/chromadb/models/fts_index_type.rb +229 -0
  38. data/lib/chromadb/openapi/lib/chromadb/models/get_attached_function_response.rb +224 -0
  39. data/lib/chromadb/openapi/lib/chromadb/models/get_response.rb +270 -0
  40. data/lib/chromadb/openapi/lib/chromadb/models/get_tenant_response.rb +230 -0
  41. data/lib/chromadb/openapi/lib/chromadb/models/get_user_identity_response.rb +246 -0
  42. data/lib/chromadb/openapi/lib/chromadb/models/heartbeat_response.rb +235 -0
  43. data/lib/chromadb/openapi/lib/chromadb/models/hnsw_configuration.rb +330 -0
  44. data/lib/chromadb/openapi/lib/chromadb/models/hnsw_index_config.rb +371 -0
  45. data/lib/chromadb/openapi/lib/chromadb/models/include.rb +210 -0
  46. data/lib/chromadb/openapi/lib/chromadb/models/int_inverted_index_type.rb +229 -0
  47. data/lib/chromadb/openapi/lib/chromadb/models/int_value_type.rb +221 -0
  48. data/lib/chromadb/openapi/lib/chromadb/models/query_response.rb +280 -0
  49. data/lib/chromadb/openapi/lib/chromadb/models/raw_where_fields.rb +230 -0
  50. data/lib/chromadb/openapi/lib/chromadb/models/schema.rb +258 -0
  51. data/lib/chromadb/openapi/lib/chromadb/models/search_payload.rb +256 -0
  52. data/lib/chromadb/openapi/lib/chromadb/models/search_payload_filter.rb +230 -0
  53. data/lib/chromadb/openapi/lib/chromadb/models/search_payload_group_by.rb +230 -0
  54. data/lib/chromadb/openapi/lib/chromadb/models/search_payload_limit.rb +230 -0
  55. data/lib/chromadb/openapi/lib/chromadb/models/search_payload_select.rb +220 -0
  56. data/lib/chromadb/openapi/lib/chromadb/models/search_request_payload.rb +220 -0
  57. data/lib/chromadb/openapi/lib/chromadb/models/search_response.rb +270 -0
  58. data/lib/chromadb/openapi/lib/chromadb/models/space.rb +210 -0
  59. data/lib/chromadb/openapi/lib/chromadb/models/spann_configuration.rb +420 -0
  60. data/lib/chromadb/openapi/lib/chromadb/models/spann_index_config.rb +536 -0
  61. data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector.rb +244 -0
  62. data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector_index_config.rb +242 -0
  63. data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector_index_type.rb +234 -0
  64. data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector_value_type.rb +221 -0
  65. data/lib/chromadb/openapi/lib/chromadb/models/string_inverted_index_type.rb +229 -0
  66. data/lib/chromadb/openapi/lib/chromadb/models/string_value_type.rb +231 -0
  67. data/lib/chromadb/openapi/lib/chromadb/models/update_collection_configuration.rb +240 -0
  68. data/lib/chromadb/openapi/lib/chromadb/models/update_collection_payload.rb +240 -0
  69. data/lib/chromadb/openapi/lib/chromadb/models/update_collection_records_payload.rb +260 -0
  70. data/lib/chromadb/openapi/lib/chromadb/models/update_hnsw_configuration.rb +345 -0
  71. data/lib/chromadb/openapi/lib/chromadb/models/update_spann_configuration.rb +260 -0
  72. data/lib/chromadb/openapi/lib/chromadb/models/update_tenant_payload.rb +220 -0
  73. data/lib/chromadb/openapi/lib/chromadb/models/upsert_collection_records_payload.rb +260 -0
  74. data/lib/chromadb/openapi/lib/chromadb/models/value_types.rb +271 -0
  75. data/lib/chromadb/openapi/lib/chromadb/models/vector_index_config.rb +261 -0
  76. data/lib/chromadb/openapi/lib/chromadb/models/vector_index_type.rb +234 -0
  77. data/lib/chromadb/openapi/lib/chromadb/version.rb +15 -0
  78. data/lib/chromadb/openapi/lib/chromadb.rb +102 -0
  79. data/lib/chromadb/openapi.rb +6 -0
  80. data/lib/chromadb/schema.rb +744 -0
  81. data/lib/chromadb/schemas/chroma-cloud-qwen.json +61 -0
  82. data/lib/chromadb/schemas/chroma-cloud-splade.json +31 -0
  83. data/lib/chromadb/schemas/chroma_bm25.json +37 -0
  84. data/lib/chromadb/search/key.rb +94 -0
  85. data/lib/chromadb/search/limit.rb +41 -0
  86. data/lib/chromadb/search/rank.rb +425 -0
  87. data/lib/chromadb/search/search.rb +73 -0
  88. data/lib/chromadb/search/select.rb +54 -0
  89. data/lib/chromadb/search/where.rb +157 -0
  90. data/lib/chromadb/search.rb +8 -0
  91. data/lib/chromadb/types/results.rb +96 -0
  92. data/lib/chromadb/types/sparse_vector.rb +86 -0
  93. data/lib/chromadb/types/validation.rb +519 -0
  94. data/lib/chromadb/types.rb +13 -0
  95. data/lib/chromadb/version.rb +5 -0
  96. data/lib/chromadb.rb +15 -0
  97. metadata +233 -0
@@ -0,0 +1,573 @@
1
+ # frozen_string_literal: true
2
+
3
module Chroma
  # Client-side handle for a single collection.
  #
  # Every public operation is translated into one HTTP call issued through
  # +client.transport.request+ against
  # +/tenants/:tenant/databases/:database/collections/:id/...+.
  # Embeddings that the caller does not supply are computed locally with the
  # collection's embedding function before the request is sent.
  #
  # The +client+ object is expected to respond to +transport+,
  # +tenant_database_path+, +get_max_batch_size+ and
  # +supports_base64_encoding?+ (and to have a +build_collection+ method,
  # used by {#fork}).
  class Collection
    attr_reader :id, :name, :metadata, :tenant, :database, :configuration, :schema

    # @param client [Object] owning client (see class comment for the methods used)
    # @param model [Hash] string-keyed collection attributes; +"configuration_json"+
    #   takes precedence over +"configuration"+ when both are present
    # @param embedding_function [#call, nil] explicit dense embedding function
    # @param data_loader [#call, nil] callable that turns a list of URIs into data
    # @param schema [Object, nil] collection schema; used as a fallback source for
    #   the embedding function and for sparse/vector index configuration
    def initialize(client:, model:, embedding_function: nil, data_loader: nil, schema: nil)
      @client = client
      @embedding_function = embedding_function
      @data_loader = data_loader

      @id = model["id"]
      @name = model["name"]
      @metadata = model["metadata"]
      @tenant = model["tenant"]
      @database = model["database"]
      @configuration = model["configuration_json"] || model["configuration"]
      @schema = schema
    end

    # @return [Object, nil] the explicitly configured embedding function, or the
    #   one resolved from the schema when none was given
    def embedding_function
      @embedding_function || @schema&.resolve_embedding_function
    end

    # Issues a GET to the collection's +/count+ endpoint.
    #
    # @return [Object] whatever the transport returns (presumably the record count)
    def count
      @client.transport.request(:get, collection_url("/count"))
    end

    # Adds records to the collection. Missing embeddings are computed from
    # +documents+ with the embedding function.
    #
    # @return [nil]
    def add(ids:, embeddings: nil, metadatas: nil, documents: nil, uris: nil)
      write_records(
        "add",
        { ids: ids, embeddings: embeddings, metadatas: metadatas, documents: documents, images: nil, uris: uris }
      )
    end

    # Fetches records by id and/or filter.
    #
    # When +"data"+ is requested the server is additionally asked for +"uris"+
    # (needed to feed the data loader), but the returned result only exposes
    # the fields the caller actually listed in +include+.
    #
    # @param include [Array<String>] fields to return; +"distances"+ is rejected
    # @raise [ArgumentError] if +"data"+ is requested without a data loader
    # @return [Types::GetResult]
    def get(ids: nil, where: nil, where_document: nil, include: [ "metadatas", "documents" ], limit: nil, offset: nil)
      # Private copy so the result never aliases the caller's array.
      include = include.dup
      Types::Validation.validate_include(include, disallowed: [ "distances" ])
      Types::Validation.validate_ids(ids) if ids
      Types::Validation.validate_where(where) if where
      Types::Validation.validate_where_document(where_document) if where_document

      if include.include?("data") && @data_loader.nil?
        raise ArgumentError, "You must set a data loader on the collection if loading from URIs."
      end

      response = @client.transport.request(
        :post,
        collection_url("/get"),
        json: {
          "ids" => ids,
          "where" => where,
          "where_document" => where_document,
          # A separate array: the extra "uris" entry must not leak into +include+.
          "include" => request_include_for(include),
          "limit" => limit,
          "offset" => offset
        }
      )

      metadatas = Types::Validation.deserialize_metadatas(response["metadatas"])
      data = load_data(include, response["uris"])
      uris = include.include?("uris") ? response["uris"] : nil

      Types::GetResult.new(
        ids: response["ids"],
        embeddings: response["embeddings"],
        metadatas: metadatas,
        documents: response["documents"],
        uris: uris,
        data: data,
        included: include
      )
    end

    # Convenience wrapper around {#get} that only sets +limit+.
    #
    # @return [Types::GetResult]
    def peek(limit: 10)
      get(limit: limit)
    end

    # Runs a nearest-neighbour query. Query embeddings are computed from
    # +query_texts+ when +query_embeddings+ is not given.
    #
    # @return [Types::QueryResult]
    def query(query_embeddings: nil, query_texts: nil, query_uris: nil, ids: nil, n_results: 10,
              where: nil, where_document: nil, include: [ "metadatas", "documents", "distances" ])
      record_set = {
        embeddings: query_embeddings,
        documents: query_texts,
        images: nil,
        uris: query_uris
      }

      prepared = prepare_query(record_set, include, ids, where, where_document, n_results)

      response = @client.transport.request(
        :post,
        collection_url("/query"),
        json: {
          "query_embeddings" => prepared[:embeddings],
          "n_results" => n_results,
          "where" => where,
          "where_document" => where_document,
          "include" => prepared[:include],
          "ids" => ids
        }
      )

      metadatas = deserialize_metadata_matrix(response["metadatas"])
      data = load_data(include, response["uris"])
      uris = include.include?("uris") ? response["uris"] : nil

      Types::QueryResult.new(
        ids: response["ids"],
        embeddings: response["embeddings"],
        metadatas: metadatas,
        documents: response["documents"],
        uris: uris,
        data: data,
        distances: response["distances"],
        included: include
      )
    end

    # Sends a PUT with whichever of name / metadata / configuration were given.
    # Note: the local +name+/+metadata+/+configuration+ readers are not refreshed.
    #
    # @return [nil]
    def modify(name: nil, metadata: nil, configuration: nil)
      payload = {}
      payload["name"] = name if name
      payload["metadata"] = Types::Validation.serialize_metadata(metadata) if metadata
      config_payload = configuration_to_payload(configuration)
      payload["configuration"] = config_payload unless config_payload.empty?

      @client.transport.request(:put, collection_url, json: payload)
      nil
    end

    # Forks this collection under a new name.
    #
    # @return [Object] the value of the client's +build_collection+ for the response
    def fork(name:)
      response = @client.transport.request(:post, collection_url("/fork"), json: { "name" => name })
      # build_collection is invoked through +send+, i.e. it may be private on the client.
      @client.send(:build_collection, response)
    end

    # Updates existing records. Unlike {#add}, embeddings are only computed
    # when documents are supplied.
    #
    # @return [nil]
    def update(ids:, embeddings: nil, metadatas: nil, documents: nil, uris: nil)
      write_records(
        "update",
        { ids: ids, embeddings: embeddings, metadatas: metadatas, documents: documents, images: nil, uris: uris },
        update: true
      )
    end

    # Inserts or updates records; prepared exactly like {#add}.
    #
    # @return [nil]
    def upsert(ids:, embeddings: nil, metadatas: nil, documents: nil, uris: nil)
      write_records(
        "upsert",
        { ids: ids, embeddings: embeddings, metadatas: metadatas, documents: documents, images: nil, uris: uris }
      )
    end

    # Deletes records matching the given ids and/or filters.
    #
    # @raise [ArgumentError] when no selector at all is supplied
    # @return [nil]
    def delete(ids: nil, where: nil, where_document: nil)
      if ids.nil? && where.nil? && where_document.nil?
        raise ArgumentError, "At least one of ids, where, or where_document must be provided"
      end

      Types::Validation.validate_ids(ids) if ids
      Types::Validation.validate_where(where) if where
      Types::Validation.validate_where_document(where_document) if where_document

      @client.transport.request(
        :post,
        collection_url("/delete"),
        json: { "ids" => ids, "where" => where, "where_document" => where_document }
      )
      nil
    end

    # Attaches a function to this collection.
    #
    # @return [Object] the transport's response, unmodified
    def attach_function(function_id:, name:, output_collection:, params: nil)
      @client.transport.request(
        :post,
        collection_url("/functions/attach"),
        json: {
          "function_id" => function_id,
          "name" => name,
          "output_collection" => output_collection,
          "params" => params
        }
      )
    end

    # Looks up an attached function by name.
    #
    # @return [Object] the transport's response, unmodified
    def get_attached_function(name:)
      # NOTE(review): +name+ is interpolated into the path as-is; confirm the
      # transport escapes path segments if names may contain reserved characters.
      @client.transport.request(:get, collection_url("/functions/#{name}"))
    end

    # Detaches an attached function by name.
    #
    # @return [nil]
    def detach_function(name:)
      # NOTE(review): same unescaped interpolation of +name+ as in get_attached_function.
      @client.transport.request(:post, collection_url("/attached_functions/#{name}/detach"), json: {})
      nil
    end

    # Runs one or more searches. String queries inside +"$knn"+ rank
    # expressions are embedded locally before the request is sent.
    #
    # @param searches [Chroma::Search::Search, Hash, Array] a single search or a list
    # @raise [ArgumentError] for inputs that are neither a Search nor a Hash
    # @return [Types::SearchResult]
    def search(searches)
      items = searches.is_a?(Array) ? searches : [ searches ]
      payloads = items.map { |item| normalize_search_input(item) }
      embedded_payloads = payloads.map { |payload| embed_search_payload(payload) }

      response = @client.transport.request(
        :post,
        collection_url("/search"),
        json: { "searches" => embedded_payloads }
      )

      Types::SearchResult.new(response)
    end

    private

    # Tenant/database come from the client at call time; the id is this collection's.
    def collection_path
      path = @client.tenant_database_path
      { tenant: path[:tenant], database: path[:database], collection_id: @id }
    end

    # Builds the request path for this collection, with an optional suffix
    # such as "/get" or "/functions/attach".
    def collection_url(suffix = "")
      path = collection_path
      "/tenants/#{path[:tenant]}/databases/#{path[:database]}/collections/#{path[:collection_id]}#{suffix}"
    end

    # Shared implementation of add / update / upsert: prepares the record set
    # and POSTs it to "/<action>".
    def write_records(action, record_set, update: false)
      prepared = prepare_records(record_set, update: update)

      @client.transport.request(
        :post,
        collection_url("/#{action}"),
        json: {
          "ids" => prepared[:ids],
          "embeddings" => prepared[:embeddings],
          "metadatas" => Types::Validation.serialize_metadatas(prepared[:metadatas]),
          "documents" => prepared[:documents],
          "uris" => prepared[:uris]
        }
      )
      nil
    end

    # Returns a NEW include list for the server request, with "uris" appended
    # when "data" was requested without it. The argument is never mutated.
    def request_include_for(include)
      requested = include.dup
      requested << "uris" if include.include?("data") && !include.include?("uris")
      requested
    end

    # Runs the data loader over the returned URIs when "data" was requested and
    # both a loader and URIs are available; otherwise returns nil.
    def load_data(include, uris)
      return nil unless include.include?("data") && @data_loader && uris

      @data_loader.call(uris)
    end

    # Invokes an embedding function, preferring +embed_query+ for query-time
    # embedding when the function provides it.
    def call_embedding_function(function, texts, is_query)
      return function.embed_query(texts) if is_query && function.respond_to?(:embed_query)

      function.call(texts)
    end

    # Embeds +texts+ with the collection's dense embedding function.
    #
    # @raise [ArgumentError] when no embedding function is available
    def embed(texts, is_query)
      ef = embedding_function
      raise ArgumentError, "Embedding function must be defined for operations requiring embeddings." unless ef

      call_embedding_function(ef, texts, is_query)
    end

    # Embeds +texts+ with the given sparse embedding function.
    def sparse_embed(sparse_embedding_function, texts, is_query)
      call_embedding_function(sparse_embedding_function, texts, is_query)
    end

    # Collects schema keys whose sparse vector index is enabled and configured
    # with both an embedding function and a source key.
    #
    # @return [Hash] schema key => sparse index config (empty without a schema)
    def get_sparse_embedding_targets
      return {} unless @schema

      targets = {}
      @schema.keys.each do |key, value_types|
        sparse_index = value_types.sparse_vector&.sparse_vector_index
        next unless sparse_index&.enabled

        config = sparse_index.config
        next unless config.embedding_function && config.source_key

        targets[key] = config
      end
      targets
    end

    # Fills in sparse embeddings for every target key returned by
    # get_sparse_embedding_targets. Values already present under a target key
    # are left untouched. The input metadatas are not mutated.
    #
    # @return [Array<Hash, nil>, nil] copied metadatas (empty ones become nil),
    #   the original argument when there are no targets, or nil when there are
    #   neither metadatas nor documents
    # @raise [ArgumentError] if an embedding function returns the wrong count
    def apply_sparse_embeddings_to_metadatas(metadatas, documents)
      sparse_targets = get_sparse_embedding_targets
      return metadatas if sparse_targets.empty?

      if metadatas.nil?
        return nil unless documents

        metadatas = Array.new(documents.length) { {} }
      end

      updated = metadatas.map { |metadata| metadata.nil? ? {} : metadata.dup }

      sparse_targets.each do |target_key, config|
        source_key = config.source_key
        embedding_function = config.embedding_function
        next unless source_key && embedding_function
        # Document-sourced targets cannot be computed without documents.
        next if source_key == DOCUMENT_KEY && documents.nil?

        sources = sparse_embedding_sources(updated, documents, source_key, target_key)
        next if sources.empty?

        sparse_embeddings = sparse_embed(embedding_function, sources.map(&:first), false)
        if sparse_embeddings.length != sources.length
          raise ArgumentError, "Sparse embedding function returned unexpected number of embeddings."
        end

        sources.each_with_index do |(_text, position), idx|
          updated[position][target_key] = sparse_embeddings[idx]
        end
      end

      updated.map { |metadata| metadata.empty? ? nil : metadata }
    end

    # Returns [text, record_index] pairs for records that still lack
    # +target_key+ and have a String source value (the document when
    # +source_key+ is DOCUMENT_KEY, otherwise the metadata value).
    def sparse_embedding_sources(metadatas, documents, source_key, target_key)
      from_document = source_key == DOCUMENT_KEY
      pairs = []
      metadatas.each_with_index do |metadata, index|
        next if metadata.key?(target_key)

        value = from_document ? documents[index] : metadata[source_key]
        pairs << [ value, index ] if value.is_a?(String)
      end
      pairs
    end

    # Normalizes and validates a record set for add/update/upsert, computes
    # missing embeddings, applies sparse embeddings to metadatas and base64
    # encodes embeddings when the client supports it.
    #
    # @param update [Boolean] when true, embeddings are only computed if
    #   documents (or images) are present
    # @return [Hash] the normalized record set
    def prepare_records(record_set, update: false)
      normalized = Types::Validation.normalize_insert_record_set(
        ids: record_set[:ids],
        embeddings: record_set[:embeddings],
        metadatas: record_set[:metadatas],
        documents: record_set[:documents],
        images: record_set[:images],
        uris: record_set[:uris]
      )

      Types::Validation.validate_insert_record_set(normalized)
      Types::Validation.validate_record_set_contains_any(normalized, %i[ids])

      max_batch_size = @client.get_max_batch_size
      if max_batch_size && max_batch_size > 0
        Types::Validation.validate_batch([ normalized[:ids] ], { max_batch_size: max_batch_size })
      end

      if normalized[:embeddings].nil?
        if update
          if normalized[:documents] || normalized[:images]
            Types::Validation.validate_record_set_for_embedding(normalized, embeddable_fields: %i[documents images])
            normalized[:embeddings] = embed(normalized[:documents] || [], false)
          end
        else
          Types::Validation.validate_record_set_for_embedding(normalized)
          normalized[:embeddings] = embed(normalized[:documents] || [], false)
        end
      end

      normalized[:metadatas] = apply_sparse_embeddings_to_metadatas(normalized[:metadatas], normalized[:documents])

      if @client.supports_base64_encoding? && normalized[:embeddings]
        normalized[:embeddings] = Types::Encoding.embeddings_to_base64_strings(normalized[:embeddings])
      end

      normalized
    end

    # Normalizes and validates query input, computes query embeddings when
    # absent and stores the server-side include list under +:include+.
    #
    # @return [Hash] the normalized record set plus +:include+
    def prepare_query(record_set, include, ids, where, where_document, n_results)
      normalized = Types::Validation.normalize_base_record_set(
        embeddings: record_set[:embeddings],
        documents: record_set[:documents],
        images: record_set[:images],
        uris: record_set[:uris]
      )

      Types::Validation.validate_base_record_set(normalized)
      Types::Validation.validate_include(include)
      Types::Validation.validate_ids(ids) if ids
      Types::Validation.validate_where(where) if where
      Types::Validation.validate_where_document(where_document) if where_document
      Types::Validation.validate_n_results(n_results) if n_results

      if normalized[:embeddings].nil?
        Types::Validation.validate_record_set_for_embedding(normalized)
        normalized[:embeddings] = embed(normalized[:documents] || [], true)
      end

      normalized[:include] = request_include_for(include)
      normalized
    end

    # Deserializes a list of metadata lists; nil rows (and a nil matrix) pass through.
    def deserialize_metadata_matrix(matrix)
      return nil if matrix.nil?

      matrix.map do |row|
        row&.map { |metadata| Types::Validation.deserialize_metadata(metadata) }
      end
    end

    # Converts a configuration object into a request payload; nil becomes {}.
    def configuration_to_payload(configuration)
      return {} if configuration.nil?
      return configuration.to_h if configuration.respond_to?(:to_h)

      configuration
    end

    # Turns a search input into a string-keyed payload hash. A Hash that
    # already carries both limit and select is passed through with stringified
    # top-level keys; any other Hash is routed through Chroma::Search::Search.
    #
    # @raise [ArgumentError] for unsupported input types
    def normalize_search_input(item)
      return item.to_h if item.is_a?(Chroma::Search::Search)

      if item.is_a?(Hash)
        has_limit = item.key?("limit") || item.key?(:limit)
        has_select = item.key?("select") || item.key?(:select)
        return stringify_keys(item) if has_limit && has_select

        return Chroma::Search::Search.new(
          where: item[:where] || item["where"],
          rank: item[:rank] || item["rank"],
          limit: item[:limit] || item["limit"],
          select: item[:select] || item["select"]
        ).to_h
      end

      raise ArgumentError, "Unsupported search input"
    end

    # Replaces the payload's "rank" with its embedded form when it is a Hash.
    def embed_search_payload(payload)
      return payload unless payload["rank"]

      embedded_rank = embed_rank_literal(payload["rank"])
      return payload unless embedded_rank.is_a?(Hash)

      payload.merge("rank" => embedded_rank)
    end

    # Recursively walks a rank expression and embeds every "$knn" hash.
    def embed_rank_literal(rank)
      return rank if rank.nil?
      return rank.map { |item| embed_rank_literal(item) } if rank.is_a?(Array)
      return rank unless rank.is_a?(Hash)

      rank.each_with_object({}) do |(key, value), acc|
        acc[key] = if key == "$knn" && value.is_a?(Hash)
          embed_knn_literal(value)
        else
          embed_rank_literal(value)
        end
      end
    end

    # Embeds a textual "$knn" query. Non-String queries are returned untouched.
    # The embedding source depends on the knn key: the collection embedding
    # function for EMBEDDING_KEY, otherwise the sparse or vector index
    # embedding function configured for that key in the schema.
    #
    # @raise [ArgumentError] when the key cannot be embedded or the embedding
    #   function returns something other than exactly one embedding
    def embed_knn_literal(knn)
      query_value = knn["query"] || knn[:query]
      return knn unless query_value.is_a?(String)

      key = knn["key"] || knn[:key] || EMBEDDING_KEY

      if key == EMBEDDING_KEY
        embeddings = embed([ query_value ], true)
        raise ArgumentError, "Embedding function returned unexpected number of embeddings." unless embeddings.length == 1

        return with_embedded_query(knn, embeddings[0])
      end

      unless @schema
        raise ArgumentError,
              "Cannot embed string query for key '#{key}': schema is not available. Provide an embedded vector or configure an embedding function."
      end

      value_types = @schema.keys[key]
      unless value_types
        raise ArgumentError,
              "Cannot embed string query for key '#{key}': key not found in schema. Provide an embedded vector or configure an embedding function."
      end

      sparse_index = value_types.sparse_vector&.sparse_vector_index
      if sparse_index&.enabled && sparse_index.config.embedding_function
        sparse_embeddings = sparse_embed(sparse_index.config.embedding_function, [ query_value ], true)
        unless sparse_embeddings.length == 1
          raise ArgumentError, "Sparse embedding function returned unexpected number of embeddings."
        end

        embedded = sparse_embeddings[0]
        return with_embedded_query(knn, { "indices" => embedded.indices, "values" => embedded.values })
      end

      vector_index = value_types.float_list&.vector_index
      if vector_index&.enabled && vector_index.config.embedding_function
        embeddings = call_embedding_function(vector_index.config.embedding_function, [ query_value ], true)
        raise ArgumentError, "Embedding function returned unexpected number of embeddings." unless embeddings.length == 1

        return with_embedded_query(knn, embeddings[0])
      end

      raise ArgumentError,
            "Cannot embed string query for key '#{key}': no embedding function configured. Provide an embedded vector or configure an embedding function."
    end

    # Returns a copy of +knn+ whose query is the embedded value. A symbol
    # +:query+ entry is dropped so the raw text does not survive next to the
    # embedded "query" (both would map to the same key once stringified).
    def with_embedded_query(knn, embedded)
      knn.reject { |entry_key, _value| entry_key == :query }.merge("query" => embedded)
    end

    # Shallow conversion of hash keys to strings.
    def stringify_keys(hash)
      hash.each_with_object({}) do |(key, value), acc|
        acc[key.to_s] = value
      end
    end
  end
end