roseflow 0.0.1 → 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.standard.yml +3 -1
- data/CHANGELOG.md +2 -2
- data/Gemfile +2 -0
- data/examples/github-repo-chat/lib/actions/clone_and_load_repository.rb +52 -0
- data/examples/github-repo-chat/lib/actions/create_prompt.rb +15 -0
- data/examples/github-repo-chat/lib/actions/embed_repository.rb +35 -0
- data/examples/github-repo-chat/lib/actions/initialize_vector_store.rb +40 -0
- data/examples/github-repo-chat/lib/actions/interact_with_model.rb +29 -0
- data/examples/github-repo-chat/lib/actions/load_documents_to_database.rb +0 -0
- data/examples/github-repo-chat/lib/actions/split_files_to_documents.rb +55 -0
- data/examples/github-repo-chat/lib/document_database.rb +0 -0
- data/examples/github-repo-chat/lib/github_chat_prompt.rb +24 -0
- data/examples/github-repo-chat/lib/github_repository_chat.rb +12 -0
- data/examples/github-repo-chat/lib/interactions/ask_llm.rb +31 -0
- data/examples/github-repo-chat/lib/interactions/github_repository_chat.rb +36 -0
- data/examples/github-repo-chat/lib/interactions/load_files_to_document_database.rb +18 -0
- data/examples/github-repo-chat/lib/interactions/load_repository.rb +20 -0
- data/examples/github-repo-chat/lib/interactions/prepare_vector_store.rb +21 -0
- data/examples/github-repo-chat/lib/repository.rb +9 -0
- data/examples/github-repo-chat/lib/repository_file.rb +31 -0
- data/examples/github-repo-chat/spec/actions/clone_and_load_repository_spec.rb +28 -0
- data/examples/github-repo-chat/spec/actions/embed_repository_spec.rb +24 -0
- data/examples/github-repo-chat/spec/actions/initialize_vector_store_spec.rb +20 -0
- data/examples/github-repo-chat/spec/actions/load_files_to_document_database_spec.rb +23 -0
- data/examples/github-repo-chat/spec/fixtures/ulid-ruby.zip +0 -0
- data/examples/github-repo-chat/spec/github_repository_chat_spec.rb +16 -0
- data/examples/github-repo-chat/spec/interactions/prepare_vector_store_spec.rb +4 -0
- data/examples/github-repo-chat/spec/spec_helper.rb +12 -0
- data/lib/roseflow/action.rb +13 -0
- data/lib/roseflow/actions/ai/resolve_model.rb +27 -0
- data/lib/roseflow/actions/ai/resolve_provider.rb +31 -0
- data/lib/roseflow/ai/model.rb +19 -0
- data/lib/roseflow/ai/provider.rb +30 -0
- data/lib/roseflow/chat/dialogue.rb +80 -0
- data/lib/roseflow/chat/exchange.rb +12 -0
- data/lib/roseflow/chat/message.rb +39 -0
- data/lib/roseflow/chat/personality.rb +10 -0
- data/lib/roseflow/embeddings/embedding.rb +26 -0
- data/lib/roseflow/finite_machine.rb +298 -0
- data/lib/roseflow/interaction/with_http_api.rb +10 -0
- data/lib/roseflow/interaction.rb +14 -0
- data/lib/roseflow/interaction_context.rb +10 -0
- data/lib/roseflow/interactions/ai/initialize_llm.rb +26 -0
- data/lib/roseflow/primitives/vector.rb +19 -0
- data/lib/roseflow/prompt.rb +17 -0
- data/lib/roseflow/text/completion.rb +16 -0
- data/lib/roseflow/text/recursive_character_splitter.rb +43 -0
- data/lib/roseflow/text/sentence_splitter.rb +42 -0
- data/lib/roseflow/text/splitter.rb +18 -0
- data/lib/roseflow/text/tokenized_text.rb +20 -0
- data/lib/roseflow/text/word_splitter.rb +14 -0
- data/lib/roseflow/tokenizer.rb +13 -0
- data/lib/roseflow/types.rb +9 -0
- data/lib/roseflow/vector_stores/base.rb +39 -0
- data/lib/roseflow/vector_stores/hnsw.proto +18 -0
- data/lib/roseflow/vector_stores/hnsw_memory_store.rb +442 -0
- data/lib/roseflow/vector_stores/hnsw_pb.rb +27 -0
- data/lib/roseflow/vector_stores/type/vector.rb +38 -0
- data/lib/roseflow/vector_stores/vector.rb +19 -0
- data/lib/roseflow/version.rb +12 -1
- data/lib/roseflow.rb +10 -1
- data/roseflow.gemspec +53 -0
- metadata +274 -7
@@ -0,0 +1,43 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "roseflow/text/splitter"
|
4
|
+
|
5
|
+
module Roseflow
  module Text
    # Splits text on the first separator from +separators+ that occurs in the
    # text, then greedily packs the resulting segments into chunks whose
    # joined length is at most +chunk_size+ characters. A single segment that
    # is longer than +chunk_size+ on its own still becomes its own chunk.
    #
    # Despite the name, oversized segments are NOT re-split recursively with
    # the finer separators; only one separator is used per call.
    class RecursiveCharacterSplitter < Splitter
      # Tried in order; "" matches every string, so splitting into single
      # characters is the final fallback.
      SEPARATORS = ["\n\n", "\n", " ", ""].freeze

      # @param separators [Array<String>, nil] candidate separators, most
      #   preferred first; nil selects SEPARATORS
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+ for Splitter
      def initialize(separators = nil, **kwargs)
        super(**kwargs)
        @separators = separators || SEPARATORS
      end

      attr_reader :chunk_size, :chunk_overlap

      # Splits +text+ into chunks.
      #
      # Chunks are joined with the separator that was used to split, so e.g.
      # paragraph breaks survive and a character-level split is not padded
      # with spaces. +chunk_overlap+ is the number of trailing segments of a
      # chunk repeated at the start of the next one; overlap segments are
      # dropped (oldest first) when keeping them would exceed +chunk_size+.
      #
      # @param text [String] the text to split
      # @return [Array<String>] the chunks; empty when +text+ yields no segments
      def split(text)
        separator = find_separator(text)
        chunks = []
        current = []

        text.split(separator).each do |segment|
          if !current.empty? && joined_size(current + [segment], separator) > chunk_size
            chunks << current.join(separator)
            current = current.last(chunk_overlap)
            # Never let the carried-over overlap push the next chunk past the limit.
            current.shift while !current.empty? && joined_size(current + [segment], separator) > chunk_size
          end
          current << segment
        end

        chunks << current.join(separator) unless current.empty?
        chunks
      end

      private

      # Returns the first configured separator contained in +text+, or the
      # last configured separator when none occurs.
      def find_separator(text)
        @separators.find { |separator| text.include?(separator) } || @separators.last
      end

      # Length of +segments+ once joined with +separator+.
      def joined_size(segments, separator)
        return 0 if segments.empty?

        segments.sum(&:size) + (separator.size * (segments.size - 1))
      end
    end # RecursiveCharacterSplitter
  end # Text
end # Roseflow
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "pragmatic_segmenter"
|
4
|
+
require "roseflow/text/splitter"
|
5
|
+
|
6
|
+
module Roseflow
  module Text
    # Splits text into sentences with PragmaticSegmenter and greedily packs
    # whole sentences into chunks whose joined length (sentences separated by
    # a single space) is at most +chunk_size+ characters. A single sentence
    # longer than +chunk_size+ still becomes its own chunk.
    class SentenceSplitter < Splitter
      # @param language [String] language code handed to PragmaticSegmenter
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+ for Splitter
      def initialize(language: "en", **kwargs)
        super(**kwargs)
        @language = language
      end

      attr_reader :chunk_size, :chunk_overlap

      # Builds a segmenter for +text+.
      #
      # A fresh segmenter is created on every call: the segmenter is bound to
      # the text it was constructed with, so caching one would make every
      # later #split re-segment the first text it ever saw.
      def segmenter(text)
        PragmaticSegmenter::Segmenter.new(text: text, language: @language)
      end

      # Splits +text+ into chunks of whole sentences.
      #
      # +chunk_overlap+ is the number of trailing sentences of a chunk repeated
      # at the start of the next one; overlap sentences are dropped (oldest
      # first) when keeping them would exceed +chunk_size+.
      #
      # @param text [String] the text to split
      # @return [Array<String>] the chunks; empty when no sentences are found
      def split(text)
        chunks = []
        current = []

        segmenter(text).segment.each do |sentence|
          if !current.empty? && joined_size(current + [sentence]) > chunk_size
            chunks << current.join(" ")
            current = current.last(chunk_overlap)
            # Never let the carried-over overlap push the next chunk past the limit.
            current.shift while !current.empty? && joined_size(current + [sentence]) > chunk_size
          end
          current << sentence
        end

        chunks << current.join(" ") unless current.empty?
        chunks
      end

      private

      # Length of +sentences+ once joined with single spaces.
      def joined_size(sentences)
        return 0 if sentences.empty?

        sentences.sum(&:size) + (sentences.size - 1)
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Roseflow
  module Text
    # Abstract base class for text splitters. Stores and validates the chunk
    # sizing parameters; subclasses implement #split. How the two sizes are
    # measured (characters vs. segments) is defined by each subclass.
    class Splitter
      attr_reader :chunk_size, :chunk_overlap

      # @param chunk_size [Integer] maximum chunk size; must be positive
      # @param chunk_overlap [Integer] overlap between consecutive chunks;
      #   must be non-negative and must not exceed +chunk_size+
      # @raise [ArgumentError] if the sizes are out of range
      def initialize(chunk_size:, chunk_overlap:)
        raise ArgumentError, "chunk size must be positive" unless chunk_size.positive?
        # A negative overlap would otherwise only fail later, inside Array#last.
        raise ArgumentError, "chunk overlap cannot be negative" if chunk_overlap.negative?
        raise ArgumentError, "chunk overlap cannot exceed chunk size" if chunk_overlap > chunk_size

        @chunk_size = chunk_size
        @chunk_overlap = chunk_overlap
      end

      # Splits +text+ into chunks. Must be overridden.
      #
      # @raise [NotImplementedError] always, in this base class
      def split(text)
        raise NotImplementedError, "this class must be extended and the #split method implemented"
      end
    end
  end
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "dry-struct"
|
4
|
+
|
5
|
+
module Roseflow
  module Text
    # A piece of text paired with the integer token IDs it was encoded into.
    class TokenizedText < Dry::Struct
      attribute :text, Types::String
      attribute :tokens, Types::Array.of(Types::Integer)

      # @return [Integer] the number of tokens
      def token_count
        tokens.length
      end

      # @return [String] the underlying text
      def to_s
        text
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "roseflow/text/recursive_character_splitter"
|
4
|
+
|
5
|
+
module Roseflow
  module Text
    # A RecursiveCharacterSplitter that splits on spaces only, so chunks are
    # built from whole words.
    class WordSplitter < RecursiveCharacterSplitter
      # @param separators [Array<String>, nil] separators to use; nil falls
      #   back to [" "] (previously nil left the splitter without separators
      #   and crashed on the first #split)
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+ for Splitter
      def initialize(separators = [" "], **kwargs)
        super(separators || [" "], **kwargs)
      end
    end # WordSplitter
  end # Text
end # Roseflow
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Roseflow
  # Abstract tokenizer. Subclasses must override both #encode and #decode;
  # the base versions only raise.
  class Tokenizer
    # Subclass hook for encoding +input+.
    #
    # @raise [NotImplementedError] always, unless overridden
    def encode(input)
      raise(NotImplementedError, "this class must be extended and the #encode method implemented")
    end

    # Subclass hook for decoding +input+.
    #
    # @raise [NotImplementedError] always, unless overridden
    def decode(input)
      raise(NotImplementedError, "this class must be extended and the #decode method implemented")
    end
  end
end
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Roseflow
  module VectorStores
    # Interface shared by vector stores. Apart from #has_embeddings?, which
    # answers false, every operation raises NotImplementedError until a
    # concrete store overrides it.
    class Base
      # @param klass [Object] the class being asked about
      # @return [Boolean] false in the base implementation
      def has_embeddings?(klass)
        false
      end

      # @raise [NotImplementedError] unless overridden
      def list_vectors(namespace = nil)
        raise(NotImplementedError, "You must implement the #list_vectors method in your vector store")
      end

      # @raise [NotImplementedError] unless overridden
      def build_vector(id, attrs)
        raise(NotImplementedError, "You must implement the #build_vector method in your vector store")
      end

      # @raise [NotImplementedError] unless overridden
      def create_vector(vector)
        raise(NotImplementedError, "You must implement the #create_vector method in your vector store")
      end

      # @raise [NotImplementedError] unless overridden
      def delete_vector(name)
        raise(NotImplementedError, "You must implement the #delete_vector method in your vector store")
      end

      # @raise [NotImplementedError] unless overridden
      def update_vector(vector)
        raise(NotImplementedError, "You must implement the #update_vector method in your vector store")
      end

      # @raise [NotImplementedError] unless overridden
      def query(query)
        raise(NotImplementedError, "You must implement the #query method in your vector store")
      end

      # @raise [NotImplementedError] unless overridden
      def find(query)
        raise(NotImplementedError, "You must implement the #find method in your vector store")
      end
    end
  end
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
syntax = "proto3";
|
2
|
+
|
3
|
+
// One node of a serialized HNSW graph (written by HNSWMemoryStore#serialize).
message HNSWGraphNode {
  // Node ID; the key of the node in the store.
  string id = 1;
  // The node's vector.
  repeated float vector = 2;
  // Highest graph level the node belongs to.
  int32 level = 3;
  // Neighbor IDs of all levels, flattened into one list. The reader
  // (HNSWMemoryStore.deserialize) splits it into consecutive slices of
  // HNSWGraph.m IDs, one slice per level.
  repeated string neighbors = 4;
}

// A whole serialized HNSW graph plus the store's construction parameters.
message HNSWGraph {
  // ID of the node searches start from.
  string entrypoint_id = 1;
  // Highest level present in the graph.
  int32 max_level = 2;
  // Distance metric name; the Ruby store handles "euclidean" and "cosine".
  string similarity_metric = 3;
  int32 dimensions = 4;
  // Neighbor slots per node and level; also the slice size for `neighbors`.
  int32 m = 5;
  int32 ef = 6;
  repeated HNSWGraphNode nodes = 7;
}
|
@@ -0,0 +1,442 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "set"
require "ulid"
require "roseflow/vector_stores/hnsw_pb"
|
5
|
+
|
6
|
+
module Roseflow
  module VectorStores
    UnsupportedSimilarityMetricError = Class.new(StandardError)

    # HNSWMemoryStore is an in-memory vector store that implements
    # the HNSW (Hierarchical Navigable Small World) algorithm.
    #
    # Every node lives on level 0 and, with probability PROBABILITY_FACTORS[0]
    # per additional level, on higher levels too. The entrypoint is a node on
    # the highest level; searches descend greedily from it and finish with a
    # best-first search of width max(ef, k) on level 0.
    class HNSWMemoryStore < Base
      PROBABILITY_FACTORS = [
        0.5,
        1 / Math::E,
      ].freeze

      # Similarity metrics understood by #distance.
      SUPPORTED_SIMILARITY_METRICS = %i[euclidean cosine].freeze

      # Initializes a new HNSWMemoryStore with the specified
      # similarity metric, dimensions, m and ef.
      #
      # @param similarity_metric [Symbol, String] :euclidean or :cosine
      # @param dimensions [Integer] the number of dimensions of the vectors
      #   (stored and serialized; vectors are not checked against it)
      # @param m [Integer] the maximum number of neighbors kept per node and level
      # @param ef [Integer] the candidate list size used while searching
      # @raise [UnsupportedSimilarityMetricError] if the similarity metric is not supported
      # @return [HNSWMemoryStore] the new HNSWMemoryStore
      def initialize(similarity_metric, dimensions, m, ef)
        unless similarity_metric.respond_to?(:to_sym) &&
            SUPPORTED_SIMILARITY_METRICS.include?(similarity_metric.to_sym)
          raise UnsupportedSimilarityMetricError, "Similarity metric #{similarity_metric} is not supported"
        end

        @similarity_metric = similarity_metric
        @dimensions = dimensions
        @m = m
        @ef = ef
        @max_level = 0
        @entrypoint = nil
        @nodes = {}
      end

      attr_accessor :nodes

      # @return [Integer] the number of nodes in the store
      def size
        @nodes.size
      end

      # Adds a new node to the vector store. Adding an ID that already exists
      # replaces the old node (it is deleted first, so no stale copy stays
      # reachable through neighbor links).
      #
      # @param node_id [String] the ID of the node
      # @param vector [Array<Float>] the vector of the node
      # @return [HNSWNode] the new node
      def add_node(node_id, vector)
        delete_node(node_id) if @nodes.key?(node_id)

        level = get_random_level
        node = HNSWNode.new(node_id, vector, level, @m)

        if @entrypoint.nil?
          @entrypoint = node
          @max_level = level
          return @nodes[node_id] = node
        end

        # Greedy descent through the levels above the new node's top level.
        current_node = @entrypoint
        @max_level.downto(level + 1) do |i|
          current_node = search_level(vector, current_node, i)
        end

        # Connect the node on every level it shares with the existing graph,
        # continuing each level's search from the closest node found above.
        [level, @max_level].min.downto(0) do |i|
          neighbors = find_neighbors(current_node, vector, i)
          update_neighbors(node, neighbors, vector, i)
          current_node = neighbors.first || current_node
        end

        # A node taller than the graph becomes the new entrypoint.
        @entrypoint = node if level > @max_level
        update_max_level(level)

        @nodes[node_id] = node
      end

      alias_method :create_vector, :add_node

      # Deletes a node from the vector store, removes every link pointing to
      # it and, if it was the entrypoint, promotes the highest remaining node.
      # Links are not repaired, so heavy deletion can reduce search recall.
      #
      # @param node_id [String] the ID of the node
      # @return [HNSWNode] the deleted node
      # @return [nil] if no node had that ID
      def delete_node(node_id)
        node = @nodes.delete(node_id)
        return unless node

        @nodes.each_value do |other|
          other.neighbors.each { |links| links&.delete(node) }
        end

        if @entrypoint.equal?(node)
          @entrypoint = @nodes.values.max_by(&:level)
          @max_level = @entrypoint ? @entrypoint.level : 0
        end

        node
      end

      alias_method :delete_vector, :delete_node

      # Finds a node in the vector store.
      #
      # @param node_id [String] the ID of the node
      # @return [HNSWNode] the found node
      # @return [nil] if the node was not found
      def find(node_id)
        @nodes[node_id]
      end

      # Serializes the vector store to a binary string.
      #
      # Each level's neighbor list is padded to exactly +m+ IDs ("" marks an
      # empty slot) so that .deserialize can regroup the flattened list into
      # per-level slices of +m+.
      #
      # @return [String] the HNSWGraph protobuf encoding
      def serialize
        graph = HNSWGraph.new(
          entrypoint_id: @entrypoint ? @entrypoint.id : "",
          max_level: @max_level,
          similarity_metric: @similarity_metric.to_s,
          dimensions: @dimensions,
          m: @m,
          ef: @ef,
          nodes: @nodes.values.map do |node|
            HNSWGraphNode.new(
              id: node.id,
              vector: node.vector,
              level: node.level,
              neighbors: serialize_neighbors(node),
            )
          end,
        )

        graph.to_proto
      end

      # Deserializes a binary string into a vector store.
      #
      # @param serialized_data [String] data produced by #serialize
      # @return [HNSWMemoryStore] the restored store
      def self.deserialize(serialized_data)
        graph = HNSWGraph.decode(serialized_data)

        hnsw = new(graph.similarity_metric, graph.dimensions, graph.m, graph.ef)

        # Create every node first so neighbor IDs can be resolved below.
        graph.nodes.each do |node|
          hnsw.nodes[node.id] = HNSWNode.new(node.id, node.vector.to_a, node.level, graph.m)
        end

        # Restore links: the flattened ID list holds one slice of +m+ IDs per level.
        graph.nodes.each do |node|
          restored = hnsw.nodes[node.id]
          node.neighbors.each_slice(graph.m).with_index do |neighbor_ids, level|
            links = restored.neighbors[level]
            # Skip slices beyond the node's own level (possible when the lists
            # were written without per-level padding).
            next unless links

            neighbor_ids.each_with_index do |neighbor_id, index|
              links[index] = hnsw.nodes[neighbor_id] if hnsw.nodes.key?(neighbor_id)
            end
          end
        end

        hnsw.instance_variable_set(:@entrypoint, hnsw.nodes[graph.entrypoint_id])
        hnsw.instance_variable_set(:@max_level, graph.max_level)

        hnsw
      end

      # Finds up to +m+ nearest neighbors of +query+ on +level+, closest first,
      # searching with a candidate list of max(ef, m).
      def find_neighbors(node, query, level)
        search_knn(node, query, [@ef.to_i, @m].max, level)
          .min_by(@m) { |candidate| distance(query, candidate.vector) }
      end

      # Links +node+ to +neighbors+ on +level+ and adds the reverse links.
      # Each neighbor keeps only its +m+ closest links, so the farthest one is
      # dropped when the new node makes the list overflow.
      #
      # @param query [Array<Float>] the vector of +node+
      def update_neighbors(node, neighbors, query, level)
        selected = neighbors.first(@m)
        node.neighbors[level] = selected

        selected.each do |neighbor|
          links = neighbor.neighbors[level]
          next unless links

          scored = links.compact.reject { |n| n.equal?(node) }.map { |n| [n, distance(neighbor.vector, n.vector)] }
          scored << [node, distance(neighbor.vector, query)]
          neighbor.neighbors[level] = scored.min_by(@m) { |_, d| d }.map(&:first)
        end
      end

      # Updates maximum level of the graph.
      def update_max_level(level)
        @max_level = level if level > @max_level
      end

      # Finds the k nearest neighbors of a vector.
      #
      # @param query [Array<Float>] the query vector
      # @param k [Integer] how many neighbors to return
      # @return [Array<HNSWNode>] up to k nodes, closest first
      def nearest_neighbors(query, k)
        return [] unless @entrypoint && k.positive?

        entry_point = @entrypoint
        @max_level.downto(1) do |level|
          entry_point = search_level(query, entry_point, level)
        end

        search_knn(entry_point, query, [@ef.to_i, k].max, 0)
          .min_by(k) { |node| distance(query, node.vector) }
      end

      # Greedily walks +level+ from +entry_point+ towards +query+ and returns
      # the closest node reached. Nodes without that level are treated as
      # having no links there.
      def search_level(query, entry_point, level)
        current = entry_point
        best_distance = distance(query, current.vector)

        loop do
          closest_neighbor, closest_distance = find_closest_neighbor(query, current.neighbors[level])
          break unless closest_neighbor && closest_distance < best_distance

          best_distance = closest_distance
          current = closest_neighbor
        end

        current
      end

      # Returns [node, distance] for the member of +neighbors+ closest to
      # +query+, or [nil, Float::INFINITY] when there is none (nil slots and a
      # nil list are ignored).
      def find_closest_neighbor(query, neighbors)
        closest_neighbor = nil
        closest_distance = Float::INFINITY

        Array(neighbors).each do |neighbor|
          next unless neighbor

          candidate_distance = distance(query, neighbor.vector)
          next unless candidate_distance < closest_distance

          closest_distance = candidate_distance
          closest_neighbor = neighbor
        end

        [closest_neighbor, closest_distance]
      end

      # Best-first search on +level+ starting at +entry_point+, keeping the
      # k closest nodes seen (returned in no particular order).
      def search_knn(entry_point, query, k, level)
        return [] unless entry_point && k.positive?

        visited = Set.new
        candidates = Set.new([entry_point])
        result = []

        until candidates.empty?
          closest = find_closest_candidate(candidates, query)
          candidates.delete(closest)
          visited.add(closest.id)

          result = update_result(result, closest, query, k)

          break if termination_condition_met?(result, closest, query, k)

          add_neighbors_to_candidates(closest, level, visited, candidates)
        end

        result
      end

      def find_closest_candidate(candidates, query)
        candidates.min_by { |c| distance(query, c.vector) }
      end

      # Adds +candidate+ to +result+ if there is room, or if it is closer than
      # the current farthest result (which it then replaces).
      def update_result(result, candidate, query, k)
        return result unless k.positive?

        if result.size < k
          result.push(candidate)
        else
          furthest_result = result.max_by { |r| distance(query, r.vector) }

          if distance(query, candidate.vector) < distance(query, furthest_result.vector)
            result.delete(furthest_result)
            result.push(candidate)
          end
        end

        result
      end

      # True once the result is full and the closest remaining candidate is
      # strictly farther than every result. Ties keep searching; stopping on
      # ties cut the search short whenever the newest result was the farthest.
      def termination_condition_met?(result, closest, query, k)
        return false if result.empty? || result.size < k

        furthest_result_distance = result.map { |r| distance(query, r.vector) }.max
        distance(query, closest.vector) > furthest_result_distance
      end

      def add_neighbors_to_candidates(closest, level, visited, candidates)
        Array(closest.neighbors[level]).each do |neighbor|
          next if neighbor.nil? || visited.include?(neighbor.id)

          candidates.add(neighbor)
        end
      end

      # Calculates the distance between two vectors.
      #
      # @raise [UnsupportedSimilarityMetricError] for an unknown metric
      def distance(from, to)
        case @similarity_metric.to_sym
        when :euclidean
          euclidean_distance(from, to)
        when :cosine
          cosine_distance(from, to)
        else
          raise UnsupportedSimilarityMetricError, "Similarity metric #{@similarity_metric} is not supported"
        end
      end

      # Calculates the euclidean distance between two vectors.
      def euclidean_distance(from, to)
        e_distance = 0
        from.each_with_index do |value, index|
          e_distance += (value - to[index])**2
        end

        Math.sqrt(e_distance)
      end

      # Calculates the cosine distance (1 - cosine similarity) between two vectors.
      def cosine_distance(from, to)
        dot_product = 0
        norm_from = 0
        norm_to = 0

        from.each_with_index do |value, index|
          dot_product += value * to[index]
          norm_from += value**2
          norm_to += to[index]**2
        end

        denominator = Math.sqrt(norm_from) * Math.sqrt(norm_to)
        # A zero vector has no direction; treat it as orthogonal (distance 1.0)
        # instead of returning NaN, which would break every later comparison.
        return 1.0 if denominator.zero?

        1 - (dot_product / denominator)
      end

      # Returns a random level for a node: each extra level is reached with
      # probability PROBABILITY_FACTORS[0]. Growth is capped at one level above
      # the current top, so the hierarchy can grow but a single lucky draw
      # cannot create a long chain of near-empty levels.
      def get_random_level
        level = 0
        level += 1 while level <= @max_level && rand < PROBABILITY_FACTORS[0]
        level
      end

      private

      # Flattens +node+'s neighbor IDs, padding every level to exactly +m+
      # entries with "" so the reader can split them back per level.
      def serialize_neighbors(node)
        node.neighbors.flat_map do |links|
          ids = Array(links).compact.map(&:id).first(@m)
          ids + Array.new(@m - ids.size, "")
        end
      end

      # HNSW vector store node.
      class HNSWNode
        attr_reader :id, :vector
        attr_accessor :level, :neighbors

        # Initializes a new node.
        #
        # @param id [String] the node ID (ULID)
        # @param vector [Array] the node vector
        # @param level [Integer] the node level
        # @param m [Integer] the number of neighbors
        # @return [HNSWNode] the node
        def initialize(id, vector, level, m)
          @id = id
          @vector = vector
          @level = level
          # One neighbor list per level 0..level, each starting as m empty slots.
          @neighbors = Array.new(level + 1) { Array.new(m) }
        end
      end

      # BoundedPriorityQueue keeps at most +max_size+ items, retaining the ones
      # with the smallest priorities. Items are arrays whose first element is
      # the priority.
      class BoundedPriorityQueue
        def initialize(max_size)
          @max_size = max_size
          # A max-heap: its root is the worst retained item, i.e. exactly the
          # one to evict when a better item arrives.
          @queue = PriorityQueue.new(order: :max)
        end

        def size
          @queue.size
        end

        # Inserts an item. When the queue is full, the item replaces the
        # retained item with the largest priority, but only if its own
        # priority is smaller.
        def push(item)
          return unless @max_size.positive?

          if size < @max_size
            @queue.push(item)
          elsif item[0] < @queue.peek[0]
            @queue.pop
            @queue.push(item)
          end
        end

        # Returns the item with the smallest priority without removing it
        # from the BoundedPriorityQueue (nil when empty).
        def peek
          @queue.to_a.min_by(&:first)
        end

        # Returns a copy of the retained items, in no particular order.
        def to_a
          @queue.to_a
        end
      end

      # PriorityQueue is a binary heap of items whose first element is the
      # priority. With the default order (:min) the item with the smallest
      # priority is on top; with order :max the largest is.
      class PriorityQueue
        # @param order [Symbol] :min (default) or :max
        # @raise [ArgumentError] for any other order
        def initialize(order: :min)
          raise ArgumentError, "order must be :min or :max" unless %i[min max].include?(order)

          @order = order
          @elements = []
        end

        def size
          @elements.size
        end

        def empty?
          @elements.empty?
        end

        def push(item)
          @elements << item
          shift_up(@elements.size - 1)
        end

        # Removes and returns the top element (the smallest priority for a
        # :min queue). Returns nil if the PriorityQueue is empty.
        def pop
          return if empty?

          swap(0, @elements.size - 1)
          element = @elements.pop
          shift_down(0)
          element
        end

        # Returns the top element without removing it from the PriorityQueue.
        def peek
          @elements.first
        end

        # Returns a copy of the elements in heap order.
        def to_a
          @elements.dup
        end

        private

        # True when +a+ belongs above +b+ in the heap.
        def before?(a, b)
          @order == :min ? a[0] < b[0] : a[0] > b[0]
        end

        def swap(i, j)
          @elements[i], @elements[j] = @elements[j], @elements[i]
        end

        def shift_up(i)
          return if i <= 0

          parent = (i - 1) / 2
          return unless before?(@elements[i], @elements[parent])

          swap(i, parent)
          shift_up(parent)
        end

        def shift_down(i)
          left_child = 2 * i + 1
          right_child = 2 * i + 2

          top = i
          top = left_child if left_child < size && before?(@elements[left_child], @elements[top])
          top = right_child if right_child < size && before?(@elements[right_child], @elements[top])

          return if top == i

          swap(i, top)
          shift_down(top)
        end
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: hnsw.proto
|
3
|
+
|
4
|
+
require "google/protobuf"
|
5
|
+
|
6
|
+
# Registers the messages of hnsw.proto in the global protobuf descriptor pool.
# NOTE(review): generated code — change hnsw.proto and regenerate instead of
# editing these definitions by hand; they must stay in sync with the schema.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("hnsw.proto", syntax: :proto3) do
    # Mirrors `message HNSWGraphNode` in hnsw.proto.
    add_message "HNSWGraphNode" do
      optional :id, :string, 1
      repeated :vector, :float, 2
      optional :level, :int32, 3
      repeated :neighbors, :string, 4
    end
    # Mirrors `message HNSWGraph` in hnsw.proto.
    add_message "HNSWGraph" do
      optional :entrypoint_id, :string, 1
      optional :max_level, :int32, 2
      optional :similarity_metric, :string, 3
      optional :dimensions, :int32, 4
      optional :m, :int32, 5
      optional :ef, :int32, 6
      repeated :nodes, :message, 7, "HNSWGraphNode"
    end
  end
end

# Top-level message classes, looked up from the pool registered above.
HNSWGraphNode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraphNode").msgclass
HNSWGraph = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraph").msgclass
|