chromadb-experimental 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/lib/chromadb/admin_client.rb +6 -0
- data/lib/chromadb/client.rb +317 -0
- data/lib/chromadb/collection.rb +573 -0
- data/lib/chromadb/embedding_functions/chroma_bm25.rb +459 -0
- data/lib/chromadb/embedding_functions/chroma_cloud_qwen.rb +139 -0
- data/lib/chromadb/embedding_functions/chroma_cloud_splade.rb +121 -0
- data/lib/chromadb/embedding_functions.rb +121 -0
- data/lib/chromadb/errors.rb +120 -0
- data/lib/chromadb/http_client.rb +142 -0
- data/lib/chromadb/openapi/lib/chromadb/api/default_api.rb +2349 -0
- data/lib/chromadb/openapi/lib/chromadb/api_client.rb +392 -0
- data/lib/chromadb/openapi/lib/chromadb/api_error.rb +58 -0
- data/lib/chromadb/openapi/lib/chromadb/configuration.rb +295 -0
- data/lib/chromadb/openapi/lib/chromadb/models/add_collection_records_payload.rb +260 -0
- data/lib/chromadb/openapi/lib/chromadb/models/attach_function_request.rb +250 -0
- data/lib/chromadb/openapi/lib/chromadb/models/attach_function_response.rb +235 -0
- data/lib/chromadb/openapi/lib/chromadb/models/attached_function_api_response.rb +361 -0
- data/lib/chromadb/openapi/lib/chromadb/models/attached_function_info.rb +240 -0
- data/lib/chromadb/openapi/lib/chromadb/models/bool_inverted_index_type.rb +229 -0
- data/lib/chromadb/openapi/lib/chromadb/models/bool_value_type.rb +221 -0
- data/lib/chromadb/openapi/lib/chromadb/models/checklist_response.rb +245 -0
- data/lib/chromadb/openapi/lib/chromadb/models/collection.rb +315 -0
- data/lib/chromadb/openapi/lib/chromadb/models/collection_configuration.rb +240 -0
- data/lib/chromadb/openapi/lib/chromadb/models/create_collection_payload.rb +260 -0
- data/lib/chromadb/openapi/lib/chromadb/models/create_database_payload.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/create_tenant_payload.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/database.rb +240 -0
- data/lib/chromadb/openapi/lib/chromadb/models/detach_function_request.rb +221 -0
- data/lib/chromadb/openapi/lib/chromadb/models/detach_function_response.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/embedding_function_new_configuration.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/error_response.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/float_inverted_index_type.rb +229 -0
- data/lib/chromadb/openapi/lib/chromadb/models/float_list_value_type.rb +221 -0
- data/lib/chromadb/openapi/lib/chromadb/models/float_value_type.rb +221 -0
- data/lib/chromadb/openapi/lib/chromadb/models/fork_collection_payload.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/fts_index_type.rb +229 -0
- data/lib/chromadb/openapi/lib/chromadb/models/get_attached_function_response.rb +224 -0
- data/lib/chromadb/openapi/lib/chromadb/models/get_response.rb +270 -0
- data/lib/chromadb/openapi/lib/chromadb/models/get_tenant_response.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/get_user_identity_response.rb +246 -0
- data/lib/chromadb/openapi/lib/chromadb/models/heartbeat_response.rb +235 -0
- data/lib/chromadb/openapi/lib/chromadb/models/hnsw_configuration.rb +330 -0
- data/lib/chromadb/openapi/lib/chromadb/models/hnsw_index_config.rb +371 -0
- data/lib/chromadb/openapi/lib/chromadb/models/include.rb +210 -0
- data/lib/chromadb/openapi/lib/chromadb/models/int_inverted_index_type.rb +229 -0
- data/lib/chromadb/openapi/lib/chromadb/models/int_value_type.rb +221 -0
- data/lib/chromadb/openapi/lib/chromadb/models/query_response.rb +280 -0
- data/lib/chromadb/openapi/lib/chromadb/models/raw_where_fields.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/schema.rb +258 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_payload.rb +256 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_payload_filter.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_payload_group_by.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_payload_limit.rb +230 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_payload_select.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_request_payload.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/search_response.rb +270 -0
- data/lib/chromadb/openapi/lib/chromadb/models/space.rb +210 -0
- data/lib/chromadb/openapi/lib/chromadb/models/spann_configuration.rb +420 -0
- data/lib/chromadb/openapi/lib/chromadb/models/spann_index_config.rb +536 -0
- data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector.rb +244 -0
- data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector_index_config.rb +242 -0
- data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector_index_type.rb +234 -0
- data/lib/chromadb/openapi/lib/chromadb/models/sparse_vector_value_type.rb +221 -0
- data/lib/chromadb/openapi/lib/chromadb/models/string_inverted_index_type.rb +229 -0
- data/lib/chromadb/openapi/lib/chromadb/models/string_value_type.rb +231 -0
- data/lib/chromadb/openapi/lib/chromadb/models/update_collection_configuration.rb +240 -0
- data/lib/chromadb/openapi/lib/chromadb/models/update_collection_payload.rb +240 -0
- data/lib/chromadb/openapi/lib/chromadb/models/update_collection_records_payload.rb +260 -0
- data/lib/chromadb/openapi/lib/chromadb/models/update_hnsw_configuration.rb +345 -0
- data/lib/chromadb/openapi/lib/chromadb/models/update_spann_configuration.rb +260 -0
- data/lib/chromadb/openapi/lib/chromadb/models/update_tenant_payload.rb +220 -0
- data/lib/chromadb/openapi/lib/chromadb/models/upsert_collection_records_payload.rb +260 -0
- data/lib/chromadb/openapi/lib/chromadb/models/value_types.rb +271 -0
- data/lib/chromadb/openapi/lib/chromadb/models/vector_index_config.rb +261 -0
- data/lib/chromadb/openapi/lib/chromadb/models/vector_index_type.rb +234 -0
- data/lib/chromadb/openapi/lib/chromadb/version.rb +15 -0
- data/lib/chromadb/openapi/lib/chromadb.rb +102 -0
- data/lib/chromadb/openapi.rb +6 -0
- data/lib/chromadb/schema.rb +744 -0
- data/lib/chromadb/schemas/chroma-cloud-qwen.json +61 -0
- data/lib/chromadb/schemas/chroma-cloud-splade.json +31 -0
- data/lib/chromadb/schemas/chroma_bm25.json +37 -0
- data/lib/chromadb/search/key.rb +94 -0
- data/lib/chromadb/search/limit.rb +41 -0
- data/lib/chromadb/search/rank.rb +425 -0
- data/lib/chromadb/search/search.rb +73 -0
- data/lib/chromadb/search/select.rb +54 -0
- data/lib/chromadb/search/where.rb +157 -0
- data/lib/chromadb/search.rb +8 -0
- data/lib/chromadb/types/results.rb +96 -0
- data/lib/chromadb/types/sparse_vector.rb +86 -0
- data/lib/chromadb/types/validation.rb +519 -0
- data/lib/chromadb/types.rb +13 -0
- data/lib/chromadb/version.rb +5 -0
- data/lib/chromadb.rb +15 -0
- metadata +233 -0
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "set"
|
|
4
|
+
module Chroma
|
|
5
|
+
module Search
|
|
6
|
+
# Immutable description of one search request: an optional filter, an
# optional ranking expression, a limit configuration and a field selection.
#
# Every builder method (+where+, +rank+, +limit+, +select+, +select_all+)
# returns a new Search and leaves the receiver untouched, so partially built
# searches can be shared and extended safely.
class Search
  attr_reader :where_clause, :rank_expression, :limit_config, :select_config

  # @param where [WhereExpression, Hash, nil] filter; stored as nil when falsy
  # @param rank [Object, nil] ranking input handed to RankExpression.from;
  #   stored as nil when falsy
  # @param limit [Object, nil] handed to Limit.from as-is (including nil)
  # @param select [Object, nil] handed to Select.from as-is (including nil)
  def initialize(where: nil, rank: nil, limit: nil, select: nil)
    # Assign every instance variable explicitly so the readers never touch
    # an undefined ivar, whichever arguments were supplied.
    @where_clause = where ? WhereExpression.from(where) : nil
    @rank_expression = rank ? RankExpression.from(rank) : nil
    @limit_config = Limit.from(limit)
    @select_config = Select.from(select)
  end

  # Returns a copy whose filter is replaced by +where+ (nil clears it).
  #
  # @param where [WhereExpression, Hash, nil]
  # @return [Search]
  def where(where = nil)
    clone_with(where: WhereExpression.from(where))
  end

  # Returns a copy whose ranking expression is replaced by +rank+.
  #
  # @param rank [Object, nil] anything RankExpression.from accepts
  # @return [Search]
  def rank(rank = nil)
    clone_with(rank: RankExpression.from(rank))
  end

  # Returns a copy with a new limit configuration.
  #
  # A Numeric +limit+ is truncated with +to_i+ and forwarded together with
  # +offset+. Any other value is forwarded alone to Limit.from, in which case
  # +offset+ is not used.
  #
  # @param limit [Numeric, Object, nil]
  # @param offset [Object, nil] only consulted when +limit+ is Numeric
  # @return [Search]
  def limit(limit = nil, offset = nil)
    config = limit.is_a?(Numeric) ? Limit.from(limit.to_i, offset) : Limit.from(limit)
    clone_with(limit: config)
  end

  # Returns a copy with a new field selection.
  #
  # A single Array, Set, Select, or Hash carrying a +:keys+/"keys" entry is
  # passed to Select.from directly; any other argument list is passed as the
  # array of keys.
  #
  # @param keys [Array<Object>] keys, or one collection-like selection input
  # @return [Search]
  def select(*keys)
    input = single_select_input?(keys) ? keys.first : keys
    clone_with(select: Select.from(input))
  end

  # Returns a copy whose selection is Select.all.
  #
  # @return [Search]
  def select_all
    clone_with(select: Select.all)
  end

  # Serializes the search into a string-keyed Hash. "limit" and "select" are
  # always present; "filter" and "rank" appear only when set.
  #
  # @return [Hash{String => Object}]
  def to_h
    payload = {
      "limit" => @limit_config.to_h,
      "select" => @select_config.to_h
    }
    payload["filter"] = @where_clause.to_h if @where_clause
    payload["rank"] = @rank_expression.to_h if @rank_expression
    payload
  end

  private

  # True when +keys+ holds exactly one argument that is itself a complete
  # selection input rather than an individual key.
  def single_select_input?(keys)
    return false unless keys.length == 1

    candidate = keys.first
    case candidate
    when Array, Set, Select then true
    when Hash then candidate.key?(:keys) || candidate.key?("keys")
    else false
    end
  end

  # Builds a sibling instance without running +initialize+, so values that
  # are already normalized are not converted a second time. Omitted keywords
  # default to the receiver's current state.
  def clone_with(where: @where_clause, rank: @rank_expression, limit: @limit_config, select: @select_config)
    self.class.allocate.tap do |copy|
      copy.instance_variable_set(:@where_clause, where)
      copy.instance_variable_set(:@rank_expression, rank)
      copy.instance_variable_set(:@limit_config, limit)
      copy.instance_variable_set(:@select_config, select)
    end
  end
end
|
|
72
|
+
end
|
|
73
|
+
end
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Chroma
|
|
4
|
+
module Search
|
|
5
|
+
# Ordered, de-duplicated list of field names a search should return.
class Select
  # @param keys [Object] coerced with Kernel#Array; each element must be a
  #   String or an object responding to +name+ (such as a Key) whose name is
  #   a String
  # @raise [TypeError] when an element does not normalize to a String
  def initialize(keys = [])
    normalized = Array(keys).map do |key|
      name = key.respond_to?(:name) ? key.name : key
      raise TypeError, "Select keys must be strings or Key instances" unless name.is_a?(String)

      name
    end
    # uniq keeps the first occurrence, preserving caller order.
    @keys = normalized.uniq
  end

  # Coerces the supported input shapes into a Select.
  #
  # Accepted inputs: another Select (copied), nil (empty selection), a Hash
  # with a +:keys+ or "keys" entry, a single String, any object responding
  # to +each+, or a single key-like object responding to +name+.
  #
  # @param input [Select, Hash, String, #each, #name, nil]
  # @return [Select]
  # @raise [TypeError] when the input shape is not supported
  def self.from(input)
    return Select.new(input.values) if input.is_a?(Select)
    return Select.new if input.nil?

    if input.is_a?(Hash) && (input.key?(:keys) || input.key?("keys"))
      return Select.new(input[:keys] || input["keys"] || [])
    end

    return Select.new([ input ]) if input.is_a?(String)
    return Select.new(input) if input.respond_to?(:each)

    # A lone key object is as valid as a lone String: the initializer already
    # accepts the same object inside a collection, so wrap it rather than
    # rejecting it.
    return Select.new([ input ]) if input.respond_to?(:name)

    raise TypeError, "Unsupported select input"
  end

  # Selection made of the document, embedding, metadata and score keys
  # exposed by +K+.
  #
  # @return [Select]
  def self.all
    Select.new([ K::DOCUMENT, K::EMBEDDING, K::METADATA, K::SCORE ])
  end

  # @return [Array<String>] a copy of the selected key names
  def values
    @keys.dup
  end

  # @return [Hash{String => Array<String>}] serialized form, keyed by "keys"
  def to_h
    { "keys" => values }
  end
end
|
|
53
|
+
end
|
|
54
|
+
end
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Chroma
|
|
4
|
+
module Search
|
|
5
|
+
# Base class for filter expressions. Subclasses implement +to_h+; this class
# supplies the +and+/+or+ combinators and the Hash parser.
class WhereExpression
  # Comparison operators accepted inside an operator hash.
  SUPPORTED_OPERATORS = %w[
    $eq $ne $gt $gte $lt $lte $in $nin $contains $not_contains $regex $not_regex
  ].freeze

  # Combines this expression with +other+ through AndWhere.combine.
  #
  # @param other [WhereExpression, Hash, nil]
  # @return [WhereExpression] the receiver itself when +other+ is nil
  def and(other)
    target = self.class.from(other)
    target ? AndWhere.combine(self, target) : self
  end

  # Combines this expression with +other+ through OrWhere.combine.
  #
  # @param other [WhereExpression, Hash, nil]
  # @return [WhereExpression] the receiver itself when +other+ is nil
  def or(other)
    target = self.class.from(other)
    target ? OrWhere.combine(self, target) : self
  end

  # @raise [NotImplementedError] always; subclasses provide serialization
  def to_h
    raise NotImplementedError
  end

  # Coerces +input+ into a WhereExpression.
  #
  # @param input [WhereExpression, Hash, nil]
  # @return [WhereExpression, nil] nil only when +input+ is nil
  # @raise [TypeError] when +input+ is neither a WhereExpression nor a Hash
  def self.from(input)
    return input if input.is_a?(WhereExpression)
    return nil if input.nil?
    raise TypeError, "Where input must be a WhereExpression or Hash" unless input.is_a?(Hash)

    parse_where_hash(input)
  end

  # @return [ComparisonWhere] comparison of +key+ against +value+
  def self.create_comparison(key, operator, value)
    ComparisonWhere.new(key, operator, value)
  end

  # Parses a where Hash holding either one logical operator ("$and"/"$or"
  # with a non-empty array of clauses) or one field. A field maps to a bare
  # value (treated as "$eq") or to a one-entry operator hash.
  #
  # Logical and comparison operators may be given as Strings or Symbols, so
  # the Ruby literal <tt>{ "$and": [...] }</tt> parses like its string-keyed
  # form. Field names are passed through unchanged.
  #
  # @param data [Hash]
  # @return [WhereExpression]
  # @raise [ArgumentError] on extra keys or an unsupported operator
  # @raise [TypeError] on a malformed logical operand
  def self.parse_where_hash(data)
    # "$and" is looked for before "$or", so a hash holding both reports the
    # "$and" error.
    %w[$and $or].each do |logical|
      present = [ logical, logical.to_sym ].find { |candidate| data.key?(candidate) }
      return parse_logical(data, present, logical) if present
    end

    raise ArgumentError, "Where hash must contain exactly one field" if data.length != 1

    field, value = data.first
    return ComparisonWhere.new(field, "$eq", value) unless value.is_a?(Hash)

    if value.length != 1
      raise ArgumentError, "Operator hash for field '#{field}' must contain exactly one operator"
    end

    operator, operand = value.first
    operator = operator.to_s if operator.is_a?(Symbol)
    unless SUPPORTED_OPERATORS.include?(operator)
      raise ArgumentError, "Unsupported where operator: #{operator}"
    end

    ComparisonWhere.new(field, operator, operand)
  end

  # Parses the clause list stored under +key+ and folds it with the combiner
  # matching +operator+ ("$and" or "$or"). A single clause is returned as is.
  def self.parse_logical(data, key, operator)
    raise ArgumentError, "#{operator} cannot be combined with other keys" if data.length != 1

    raw = data[key]
    unless raw.is_a?(Array) && !raw.empty?
      raise TypeError, "#{operator} must be a non-empty array"
    end

    combiner = operator == "$and" ? AndWhere : OrWhere
    conditions = raw.each_with_index.map do |item, index|
      from(item) || raise(TypeError, "Invalid where clause at index #{index}")
    end
    conditions.reduce { |combined, condition| combiner.combine(combined, condition) }
  end
  private_class_method :parse_logical
end
|
|
90
|
+
|
|
91
|
+
# Conjunction of where-expressions, serialized under the "$and" key.
class AndWhere < WhereExpression
  # @param conditions [Array<WhereExpression>] operands of the conjunction
  def initialize(conditions)
    # Copy on the way in, mirroring the copy +operands+ hands out, so later
    # mutation of the caller's array cannot alter this expression.
    @conditions = conditions.dup
  end

  # @return [Hash] <tt>{ "$and" => [...] }</tt> with each operand serialized
  def to_h
    { "$and" => @conditions.map(&:to_h) }
  end

  # @return [Array<WhereExpression>] a copy of the operand list
  def operands
    @conditions.dup
  end

  # Joins two expressions into one conjunction, splicing in the operands of
  # any side that is already an AndWhere so the result stays flat.
  #
  # @param left [WhereExpression]
  # @param right [WhereExpression]
  # @return [WhereExpression] the sole operand when only one remains after
  #   flattening, otherwise a new AndWhere
  def self.combine(left, right)
    flattened = [ left, right ].flat_map do |expr|
      expr.is_a?(AndWhere) ? expr.operands : [ expr ]
    end
    return flattened.first if flattened.length == 1

    AndWhere.new(flattened)
  end
end
|
|
117
|
+
|
|
118
|
+
# Disjunction of where-expressions, serialized under the "$or" key.
class OrWhere < WhereExpression
  # @param conditions [Array<WhereExpression>] operands of the disjunction
  def initialize(conditions)
    # Copy on the way in, mirroring the copy +operands+ hands out, so later
    # mutation of the caller's array cannot alter this expression.
    @conditions = conditions.dup
  end

  # @return [Hash] <tt>{ "$or" => [...] }</tt> with each operand serialized
  def to_h
    { "$or" => @conditions.map(&:to_h) }
  end

  # @return [Array<WhereExpression>] a copy of the operand list
  def operands
    @conditions.dup
  end

  # Joins two expressions into one disjunction, splicing in the operands of
  # any side that is already an OrWhere so the result stays flat.
  #
  # @param left [WhereExpression]
  # @param right [WhereExpression]
  # @return [WhereExpression] the sole operand when only one remains after
  #   flattening, otherwise a new OrWhere
  def self.combine(left, right)
    flattened = [ left, right ].flat_map do |expr|
      expr.is_a?(OrWhere) ? expr.operands : [ expr ]
    end
    return flattened.first if flattened.length == 1

    OrWhere.new(flattened)
  end
end
|
|
144
|
+
|
|
145
|
+
# Leaf where-expression: a single key paired with one operator and operand.
class ComparisonWhere < WhereExpression
  # @param key [Object] hash key the comparison is serialized under
  # @param operator [Object] operator used as the inner hash key
  # @param value [Object] operand stored under the operator
  def initialize(key, operator, value)
    @key, @operator, @value = key, operator, value
  end

  # @return [Hash] <tt>{ key => { operator => value } }</tt>
  def to_h
    comparison = { @operator => @value }
    { @key => comparison }
  end
end
|
|
156
|
+
end
|
|
157
|
+
end
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Chroma
|
|
4
|
+
module Types
|
|
5
|
+
# Plain value holder for the fields of a "get" response.
class GetResult
  attr_reader :ids, :embeddings, :metadatas, :documents, :uris, :data, :included

  # All keywords except +included+ are required; values are stored as given.
  def initialize(ids:, embeddings:, metadatas:, documents:, uris:, data:, included: nil)
    @ids, @embeddings, @metadatas = ids, embeddings, metadatas
    @documents, @uris, @data = documents, uris, data
    @included = included
  end

  # @return [Hash{Symbol => Object}] every stored field keyed by its name, in
  #   the order ids, embeddings, metadatas, documents, uris, data, included
  def to_h
    %i[ids embeddings metadatas documents uris data included].each_with_object({}) do |field, hash|
      hash[field] = instance_variable_get(:"@#{field}")
    end
  end
end
|
|
30
|
+
|
|
31
|
+
# GetResult extended with the +distances+ field of a query response.
class QueryResult < GetResult
  attr_reader :distances

  # Takes the same keywords as GetResult plus the required +distances+.
  def initialize(ids:, embeddings:, metadatas:, documents:, uris:, data:, distances:, included: nil)
    @distances = distances
    super(
      ids: ids,
      embeddings: embeddings,
      metadatas: metadatas,
      documents: documents,
      uris: uris,
      data: data,
      included: included
    )
  end

  # @return [Hash{Symbol => Object}] the parent's fields plus +:distances+
  def to_h
    inherited = super
    inherited.merge(distances: @distances)
  end
end
|
|
43
|
+
|
|
44
|
+
# Wraps a raw search response Hash (string keys) and exposes its columnar
# payloads, plus a row-oriented view through +rows+.
class SearchResult
  attr_reader :ids, :documents, :embeddings, :metadatas, :scores, :select

  # @param response [Hash] must contain "ids"; "documents", "embeddings",
  #   "metadatas", "scores" and "select" are optional
  # @raise [KeyError] when "ids" is missing
  def initialize(response)
    @ids = response.fetch("ids")
    batch_count = @ids.length

    @documents = normalize_payload_array(response["documents"], batch_count)
    @embeddings = normalize_payload_array(response["embeddings"], batch_count)
    @metadatas = normalize_payload_array(response["metadatas"], batch_count).map do |batch|
      batch ? Types::Validation.deserialize_metadatas(batch) : nil
    end
    @scores = normalize_payload_array(response["scores"], batch_count)
    @select = response["select"] || []
  end

  # Row-oriented view of the result: one array per entry of +ids+, each
  # holding one Hash per id. A row always carries +:id+; +:document+,
  # +:embedding+, +:metadata+ and +:score+ are added only when non-nil.
  #
  # @return [Array<Array<Hash{Symbol => Object}>>]
  def rows
    @ids.each_with_index.map do |id_batch, batch_index|
      # Insertion order here fixes the key order of every row.
      columns = {
        document: @documents[batch_index] || [],
        embedding: @embeddings[batch_index] || [],
        metadata: @metadatas[batch_index] || [],
        score: @scores[batch_index] || []
      }

      id_batch.each_with_index.map do |id, row_index|
        columns.each_with_object({ id: id }) do |(field, column), row|
          value = column[row_index]
          row[field] = value unless value.nil?
        end
      end
    end
  end

  private

  # Returns a copy of +payload+ with each truthy entry dup'ed, padded with
  # nil up to +count+ entries. A nil payload becomes +count+ nils; a payload
  # longer than +count+ is not truncated.
  def normalize_payload_array(payload, count)
    return Array.new(count) if payload.nil?

    copies = payload.map { |item| item ? item.dup : nil }
    copies.fill(nil, copies.length...count) if copies.length < count
    copies
  end
end
|
|
95
|
+
end
|
|
96
|
+
end
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Chroma
|
|
4
|
+
module Types
|
|
5
|
+
# Sparse vector described by parallel +indices+ and +values+ arrays, with
# optional per-entry +labels+. Instances are validated on construction.
class SparseVector
  attr_reader :indices, :values, :labels

  # @param indices [Array<Integer>] non-negative, strictly ascending
  # @param values [Array<Numeric>] same length as +indices+
  # @param labels [Array, nil] optional; same length as +indices+ when given
  # @raise [ArgumentError] when any of the constraints above is violated
  def initialize(indices:, values:, labels: nil)
    @indices = indices
    @values = values
    @labels = labels
    validate!
  end

  # Serialized form: the type marker plus "indices" and "values"; labels are
  # emitted under "tokens" only when present.
  #
  # @return [Hash{String => Object}]
  def to_h
    result = {
      TYPE_KEY => SPARSE_VECTOR_TYPE_VALUE,
      "indices" => @indices,
      "values" => @values
    }
    result["tokens"] = @labels if @labels
    result
  end

  # Rebuilds a SparseVector from its serialized Hash form.
  #
  # @param data [Hash] must carry the sparse-vector type marker
  # @return [SparseVector]
  # @raise [ArgumentError] when +data+ is not a Hash or has the wrong marker
  # @raise [KeyError] when "indices" or "values" is missing
  def self.from_h(data)
    # Reject non-Hash input first: indexing it to build the error message
    # would raise NoMethodError/TypeError instead of the intended
    # ArgumentError.
    unless data.is_a?(Hash)
      raise ArgumentError,
            "Expected a Hash with #{TYPE_KEY}='#{SPARSE_VECTOR_TYPE_VALUE}', got #{data.class}"
    end
    unless data[TYPE_KEY] == SPARSE_VECTOR_TYPE_VALUE
      raise ArgumentError,
            "Expected #{TYPE_KEY}='#{SPARSE_VECTOR_TYPE_VALUE}', got #{data[TYPE_KEY].inspect}"
    end

    new(indices: data.fetch("indices"), values: data.fetch("values"), labels: data["tokens"])
  end

  private

  # Runs every check in a fixed order, so the first violated rule determines
  # which error is raised.
  def validate!
    validate_containers!
    validate_labels! if @labels
    validate_index_entries!
    validate_value_entries!
    validate_ascending!
  end

  # indices and values must both be Arrays of equal length.
  def validate_containers!
    unless @indices.is_a?(Array)
      raise ArgumentError, "Expected SparseVector indices to be an Array, got #{@indices.class}"
    end
    unless @values.is_a?(Array)
      raise ArgumentError, "Expected SparseVector values to be an Array, got #{@values.class}"
    end
    return if @indices.length == @values.length

    raise ArgumentError,
          "SparseVector indices and values must have the same length, got #{@indices.length} indices and #{@values.length} values"
  end

  # labels, when given, must be an Array as long as indices.
  def validate_labels!
    unless @labels.is_a?(Array)
      raise ArgumentError, "Expected SparseVector labels to be an Array, got #{@labels.class}"
    end
    return if @labels.length == @indices.length

    raise ArgumentError,
          "SparseVector labels must match indices length, got #{@labels.length} labels and #{@indices.length} indices"
  end

  # Every index must be a non-negative Integer.
  def validate_index_entries!
    @indices.each_with_index do |idx, i|
      unless idx.is_a?(Integer)
        raise ArgumentError,
              "SparseVector indices must be integers, got #{idx.inspect} at position #{i}"
      end
      if idx.negative?
        raise ArgumentError,
              "SparseVector indices must be non-negative, got #{idx} at position #{i}"
      end
    end
  end

  # Every value must be Numeric.
  def validate_value_entries!
    @values.each_with_index do |val, i|
      next if val.is_a?(Numeric)

      raise ArgumentError,
            "SparseVector values must be numeric, got #{val.inspect} at position #{i}"
    end
  end

  # Indices must be strictly ascending. each_cons(2) yields nothing for
  # fewer than two entries, so no explicit length guard is needed.
  def validate_ascending!
    @indices.each_cons(2).with_index do |(prev, curr), i|
      if curr <= prev
        raise ArgumentError,
              "SparseVector indices must be strictly ascending, found indices[#{i + 1}]=#{curr} <= indices[#{i}]=#{prev}"
      end
    end
  end
end
|
|
85
|
+
end
|
|
86
|
+
end
|