roseflow 0.0.1 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.standard.yml +3 -1
- data/CHANGELOG.md +2 -2
- data/Gemfile +2 -0
- data/examples/github-repo-chat/lib/actions/clone_and_load_repository.rb +52 -0
- data/examples/github-repo-chat/lib/actions/create_prompt.rb +15 -0
- data/examples/github-repo-chat/lib/actions/embed_repository.rb +35 -0
- data/examples/github-repo-chat/lib/actions/initialize_vector_store.rb +40 -0
- data/examples/github-repo-chat/lib/actions/interact_with_model.rb +29 -0
- data/examples/github-repo-chat/lib/actions/load_documents_to_database.rb +0 -0
- data/examples/github-repo-chat/lib/actions/split_files_to_documents.rb +55 -0
- data/examples/github-repo-chat/lib/document_database.rb +0 -0
- data/examples/github-repo-chat/lib/github_chat_prompt.rb +24 -0
- data/examples/github-repo-chat/lib/github_repository_chat.rb +12 -0
- data/examples/github-repo-chat/lib/interactions/ask_llm.rb +31 -0
- data/examples/github-repo-chat/lib/interactions/github_repository_chat.rb +36 -0
- data/examples/github-repo-chat/lib/interactions/load_files_to_document_database.rb +18 -0
- data/examples/github-repo-chat/lib/interactions/load_repository.rb +20 -0
- data/examples/github-repo-chat/lib/interactions/prepare_vector_store.rb +21 -0
- data/examples/github-repo-chat/lib/repository.rb +9 -0
- data/examples/github-repo-chat/lib/repository_file.rb +31 -0
- data/examples/github-repo-chat/spec/actions/clone_and_load_repository_spec.rb +28 -0
- data/examples/github-repo-chat/spec/actions/embed_repository_spec.rb +24 -0
- data/examples/github-repo-chat/spec/actions/initialize_vector_store_spec.rb +20 -0
- data/examples/github-repo-chat/spec/actions/load_files_to_document_database_spec.rb +23 -0
- data/examples/github-repo-chat/spec/fixtures/ulid-ruby.zip +0 -0
- data/examples/github-repo-chat/spec/github_repository_chat_spec.rb +16 -0
- data/examples/github-repo-chat/spec/interactions/prepare_vector_store_spec.rb +4 -0
- data/examples/github-repo-chat/spec/spec_helper.rb +12 -0
- data/lib/roseflow/action.rb +13 -0
- data/lib/roseflow/actions/ai/resolve_model.rb +27 -0
- data/lib/roseflow/actions/ai/resolve_provider.rb +31 -0
- data/lib/roseflow/ai/model.rb +19 -0
- data/lib/roseflow/ai/provider.rb +30 -0
- data/lib/roseflow/chat/dialogue.rb +80 -0
- data/lib/roseflow/chat/exchange.rb +12 -0
- data/lib/roseflow/chat/message.rb +39 -0
- data/lib/roseflow/chat/personality.rb +10 -0
- data/lib/roseflow/embeddings/embedding.rb +26 -0
- data/lib/roseflow/finite_machine.rb +298 -0
- data/lib/roseflow/interaction/with_http_api.rb +10 -0
- data/lib/roseflow/interaction.rb +14 -0
- data/lib/roseflow/interaction_context.rb +10 -0
- data/lib/roseflow/interactions/ai/initialize_llm.rb +26 -0
- data/lib/roseflow/primitives/vector.rb +19 -0
- data/lib/roseflow/prompt.rb +17 -0
- data/lib/roseflow/text/completion.rb +16 -0
- data/lib/roseflow/text/recursive_character_splitter.rb +43 -0
- data/lib/roseflow/text/sentence_splitter.rb +42 -0
- data/lib/roseflow/text/splitter.rb +18 -0
- data/lib/roseflow/text/tokenized_text.rb +20 -0
- data/lib/roseflow/text/word_splitter.rb +14 -0
- data/lib/roseflow/tokenizer.rb +13 -0
- data/lib/roseflow/types.rb +9 -0
- data/lib/roseflow/vector_stores/base.rb +39 -0
- data/lib/roseflow/vector_stores/hnsw.proto +18 -0
- data/lib/roseflow/vector_stores/hnsw_memory_store.rb +442 -0
- data/lib/roseflow/vector_stores/hnsw_pb.rb +27 -0
- data/lib/roseflow/vector_stores/type/vector.rb +38 -0
- data/lib/roseflow/vector_stores/vector.rb +19 -0
- data/lib/roseflow/version.rb +12 -1
- data/lib/roseflow.rb +10 -1
- data/roseflow.gemspec +53 -0
- metadata +274 -7
@@ -0,0 +1,43 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "roseflow/text/splitter"
|
4
|
+
|
5
|
+
module Roseflow
|
6
|
+
module Text
|
7
|
+
# Splits text on the first separator (from a priority list) that occurs in
# the text and greedily packs the resulting segments into chunks.
#
# Size accounting is approximate: chunk_size is compared against a running
# counter of segment lengths plus a per-segment allowance, and chunk_overlap
# is the number of trailing *segments* (not characters) carried into the
# next chunk.
class RecursiveCharacterSplitter < Splitter
  # Separators tried in order; the empty string splits into single characters.
  SEPARATORS = ["\n\n", "\n", " ", ""].freeze

  # @param separators [Array<String>, nil] separators in priority order,
  #   defaults to SEPARATORS
  # @param kwargs [Hash] forwarded to Splitter#initialize
  #   (chunk_size:, chunk_overlap:)
  def initialize(separators = nil, **kwargs)
    super(**kwargs)
    @separators = separators || SEPARATORS
  end

  attr_reader :chunk_size, :chunk_overlap

  # Splits +text+ into chunks.
  #
  # @param text [String] the text to split
  # @return [Array<String>] the chunks; empty chunks are never returned
  def split(text)
    separator = find_separator(text)
    # Character-level splitting must not inject spaces between characters;
    # every other separator keeps the historical single-space join.
    joiner = separator == "" ? "" : " "
    current_size = 0
    chunks = [[]]

    text.split(separator).each do |segment|
      if current_size + segment.size > chunk_size
        # Start a new chunk seeded with the trailing segments of the previous one.
        carried = [chunks.last.last(chunk_overlap), segment].flatten
        current_size = carried.sum(&:size) + chunk_overlap
        chunks << carried
      else
        current_size += segment.size + chunks.last.size
        chunks.last << segment
      end
    end

    # The seed chunk stays empty when the very first segment already
    # overflows (or the text is empty); drop it instead of returning "".
    chunks.reject(&:empty?).map { |chunk| chunk.join(joiner) }
  end

  private

  # Returns the first configured separator contained in +text+, falling back
  # to the last configured separator.
  def find_separator(text)
    @separators.find { |separator| text.include?(separator) } || @separators.last
  end
end # RecursiveCharacterSplitter
|
42
|
+
end # Text
|
43
|
+
end # Roseflow
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "pragmatic_segmenter"
|
4
|
+
require "roseflow/text/splitter"
|
5
|
+
|
6
|
+
module Roseflow
|
7
|
+
module Text
|
8
|
+
# Splits text into sentences with PragmaticSegmenter and greedily packs the
# sentences into chunks.
#
# chunk_overlap is the number of trailing *sentences* carried into the next
# chunk; chunk_size is compared against an approximate running size counter.
class SentenceSplitter < Splitter
  # @param language [String] language code handed to PragmaticSegmenter
  # @param kwargs [Hash] forwarded to Splitter#initialize
  #   (chunk_size:, chunk_overlap:)
  def initialize(language: "en", **kwargs)
    super(**kwargs)
    @language = language
  end

  attr_reader :chunk_size, :chunk_overlap

  # Builds a segmenter for +text+.
  #
  # Deliberately not memoized: the segmenter is constructed with the text it
  # will segment, so caching the first instance made every later #split call
  # return the sentences of the first text.
  #
  # @param text [String] the text to segment
  # @return [PragmaticSegmenter::Segmenter]
  def segmenter(text)
    PragmaticSegmenter::Segmenter.new(text: text, language: @language)
  end

  # Splits +text+ into chunks of sentences.
  #
  # @param text [String] the text to split
  # @return [Array<String>] the chunks; empty chunks are never returned
  def split(text)
    current_size = 0
    chunks = [[]]

    segmenter(text).segment.each do |sentence|
      if current_size + sentence.size > chunk_size
        # Start a new chunk seeded with the trailing sentences of the previous one.
        carried = [chunks.last.last(chunk_overlap), sentence].flatten
        current_size = carried.sum(&:size) + chunk_overlap
        chunks << carried
      else
        current_size += sentence.size + chunks.last.size
        chunks.last << sentence
      end
    end

    # The seed chunk stays empty when the first sentence already overflows
    # (or there are no sentences); drop it instead of returning "".
    chunks.reject(&:empty?).map { |chunk| chunk.join(" ") }
  end
end
|
41
|
+
end
|
42
|
+
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Roseflow
|
4
|
+
module Text
|
5
|
+
# Abstract base class for text splitters. Holds the chunking parameters and
# defines the #split contract that subclasses implement.
class Splitter
  # Readers live on the base class because the base class owns these values.
  attr_reader :chunk_size, :chunk_overlap

  # @param chunk_size [Integer] size limit used when packing chunks
  # @param chunk_overlap [Integer] amount carried over between chunks
  # @raise [ArgumentError] if chunk_overlap is negative or exceeds chunk_size
  def initialize(chunk_size:, chunk_overlap:)
    raise ArgumentError, "chunk overlap cannot exceed chunk size" if chunk_overlap > chunk_size
    # A negative overlap would otherwise only fail later, inside Array#last.
    raise ArgumentError, "chunk overlap cannot be negative" if chunk_overlap.negative?

    @chunk_size = chunk_size
    @chunk_overlap = chunk_overlap
  end

  # @param text [String] the text to split
  # @raise [NotImplementedError] always; subclasses must override
  def split(text)
    raise NotImplementedError, "this class must be extended and the #split method implemented"
  end
end
|
17
|
+
end
|
18
|
+
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "dry-struct"
|
4
|
+
|
5
|
+
module Roseflow
|
6
|
+
module Text
|
7
|
+
# Value object pairing a text with the integer tokens that belong to it.
class TokenizedText < Dry::Struct
  attribute :text, Types::String
  attribute :tokens, Types::Array.of(Types::Integer)

  # @return [Integer] how many tokens are stored
  def token_count
    tokens.length
  end

  # @return [String] the +text+ attribute
  def to_s
    text
  end
end
|
19
|
+
end
|
20
|
+
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "roseflow/text/recursive_character_splitter"
|
4
|
+
|
5
|
+
module Roseflow
|
6
|
+
module Text
|
7
|
+
# RecursiveCharacterSplitter preconfigured to split on single spaces.
class WordSplitter < RecursiveCharacterSplitter
  # @param separators [Array<String>, nil] separators in priority order;
  #   an explicit nil falls back to the single-space default instead of
  #   leaving the splitter without separators
  # @param kwargs [Hash] forwarded to the parent (chunk_size:, chunk_overlap:)
  def initialize(separators = [" "], **kwargs)
    # The parent stores the separators, so pass them up instead of
    # re-assigning the instance variable here.
    super(separators || [" "], **kwargs)
  end
end # WordSplitter
|
13
|
+
end # Text
|
14
|
+
end # Roseflow
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Roseflow
|
4
|
+
# Abstract tokenizer interface. Both operations raise until a subclass
# overrides them.
class Tokenizer
  # @param input [Object] value to encode
  # @raise [NotImplementedError] always in this base class
  def encode(input)
    raise NotImplementedError, "this class must be extended and the ##{__method__} method implemented"
  end

  # @param input [Object] value to decode
  # @raise [NotImplementedError] always in this base class
  def decode(input)
    raise NotImplementedError, "this class must be extended and the ##{__method__} method implemented"
  end
end
|
13
|
+
end
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Roseflow
|
4
|
+
module VectorStores
|
5
|
+
# Abstract vector store interface. Every operation except #has_embeddings?
# raises NotImplementedError until a concrete store overrides it.
class Base
  # @param klass [Object] ignored by the base implementation
  # @return [Boolean] always false here
  def has_embeddings?(klass)
    false
  end

  # @raise [NotImplementedError] always in this base class
  def list_vectors(namespace = nil)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end

  # @raise [NotImplementedError] always in this base class
  def build_vector(id, attrs)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end

  # @raise [NotImplementedError] always in this base class
  def create_vector(vector)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end

  # @raise [NotImplementedError] always in this base class
  def delete_vector(name)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end

  # @raise [NotImplementedError] always in this base class
  def update_vector(vector)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end

  # @raise [NotImplementedError] always in this base class
  def query(query)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end

  # @raise [NotImplementedError] always in this base class
  def find(query)
    raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
  end
end
|
38
|
+
end
|
39
|
+
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
syntax = "proto3";
|
2
|
+
|
3
|
+
// One node of a serialized HNSW graph, as written by HNSWMemoryStore#serialize.
message HNSWGraphNode {
|
4
|
+
string id = 1; // node identifier; also the key other nodes use in `neighbors`
|
5
|
+
repeated float vector = 2; // the node's vector components
|
6
|
+
int32 level = 3; // the node's level in the graph
|
7
|
+
repeated string neighbors = 4; // ids of the node's neighbors, all levels flattened into one list
|
8
|
+
}
|
9
|
+
|
10
|
+
// A whole serialized HNSW store: its construction parameters plus every node.
message HNSWGraph {
|
11
|
+
string entrypoint_id = 1; // id of the node searches start from
|
12
|
+
int32 max_level = 2; // highest level recorded by the store
|
13
|
+
string similarity_metric = 3; // metric name the store was created with
|
14
|
+
int32 dimensions = 4; // vector dimensionality the store was created with
|
15
|
+
int32 m = 5; // neighbor slots per node per level
|
16
|
+
int32 ef = 6; // `ef` parameter the store was created with
|
17
|
+
repeated HNSWGraphNode nodes = 7; // every node in the store
|
18
|
+
}
|
@@ -0,0 +1,442 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "set"
require "ulid"
|
4
|
+
require "roseflow/vector_stores/hnsw_pb"
|
5
|
+
|
6
|
+
module Roseflow
|
7
|
+
module VectorStores
|
8
|
+
UnsupportedSimilarityMetricError = Class.new(StandardError)
|
9
|
+
|
10
|
+
# HNSWMemoryStore is an in-memory vector store that implements
|
11
|
+
# the HNSW algorithm.
|
12
|
+
class HNSWMemoryStore < Base
|
13
|
+
PROBABILITY_FACTORS = [
|
14
|
+
0.5,
|
15
|
+
1 / Math::E,
|
16
|
+
].freeze
|
17
|
+
|
18
|
+
# Initializes a new HNSWMemoryStore with the specified
|
19
|
+
# similarity metric, dimensions, m and ef.
|
20
|
+
#
|
21
|
+
# @param similarity_metric [Symbol] the similarity metric to use
|
22
|
+
# @param dimensions [Integer] the number of dimensions of the vectors
|
23
|
+
# @param m [Integer] the number of neighbors to consider when adding a node
|
24
|
+
# @param ef [Integer] the number of neighbors to consider when searching
|
25
|
+
# @note An unsupported similarity metric is not rejected here; UnsupportedSimilarityMetricError is raised by #distance the first time a distance is computed.
|
26
|
+
# @return [HNSWMemoryStore] the new HNSWMemoryStore
|
27
|
+
def initialize(similarity_metric, dimensions, m, ef)
  # Caller-supplied configuration, stored as given (no validation here).
  @similarity_metric, @dimensions, @m, @ef = similarity_metric, dimensions, m, ef

  # Graph state: no nodes and no entry point until the first node is added.
  @nodes = {}
  @entrypoint = nil
  @max_level = 0
end
|
36
|
+
|
37
|
+
delegate :size, to: :@nodes
|
38
|
+
|
39
|
+
attr_accessor :nodes
|
40
|
+
|
41
|
+
# Adds a new node to the vector store.
|
42
|
+
#
|
43
|
+
# @param node_id [String] the ID of the node
|
44
|
+
# @param vector [Array<Float>] the vector of the node
|
45
|
+
# @return [HNSWNode] the new node
|
46
|
+
def add_node(node_id, vector)
  node_level = get_random_level
  new_node = HNSWNode.new(node_id, vector, node_level, @m)

  if @entrypoint.nil?
    # The very first node becomes the entry point and needs no linking.
    @entrypoint = new_node
  else
    update_max_level(node_level)
    cursor = search_level(vector, @entrypoint, @max_level)

    # Walk from the top level down to level 0, linking the new node on every
    # level it participates in and re-seeding the cursor for the next level.
    (0..@max_level).reverse_each do |layer|
      if layer <= node_level
        update_neighbors(new_node, find_neighbors(cursor, vector, layer), vector, layer)
      end

      cursor = search_level(vector, @entrypoint, layer - 1) unless layer.zero?
    end
  end

  # Register the node last; the assignment's value (the node) is returned.
  @nodes[node_id] = new_node
end
|
69
|
+
|
70
|
+
alias_method :create_vector, :add_node
|
71
|
+
|
72
|
+
# Deletes a node from the vector store.
|
73
|
+
#
|
74
|
+
# @param node_id [String] the ID of the node
|
75
|
+
# @return [HNSWNode] the deleted node
|
76
|
+
def delete_node(node_id)
  removed = @nodes.delete(node_id)
  return removed if removed.nil?

  # Removing the hash entry alone left the node reachable through other
  # nodes' neighbor slots, so searches could still return it. Blank out
  # every slot that points at it (nil slots are skipped by the searches).
  @nodes.each_value do |node|
    node.neighbors.each do |level_neighbors|
      level_neighbors.map! { |neighbor| neighbor.equal?(removed) ? nil : neighbor }
    end
  end

  # If the entry point itself was removed, promote the highest-level
  # survivor and shrink max_level to match, so searches never index a
  # level the new entry point does not have.
  if @entrypoint.equal?(removed)
    @entrypoint = @nodes.each_value.max_by(&:level)
    @max_level = @entrypoint ? @entrypoint.level : 0
  end

  removed
end
|
79
|
+
|
80
|
+
alias_method :delete_vector, :delete_node
|
81
|
+
|
82
|
+
# Finds a node in the vector store.
|
83
|
+
#
|
84
|
+
# @param node_id [String] the ID of the node
|
85
|
+
# @return [HNSWNode] the found node
|
86
|
+
# @return [nil] if the node was not found
|
87
|
+
def find(node_id)
  # Plain lookup by id; an unknown id yields nil rather than raising.
  @nodes.fetch(node_id, nil)
end
|
90
|
+
|
91
|
+
# Serializes the vector store to a binary string.
|
92
|
+
def serialize
  # Builds an HNSWGraph message from the store and returns graph.to_proto.
  #
  # Neighbor ids are written as exactly @m entries per level, with "" marking
  # an empty slot. Dropping empty slots (as before) shifted ids across level
  # boundaries, because .deserialize re-splits the flat list into groups of m.
  # The deserializer ignores ids it does not know, so "" restores as an empty
  # slot (assumes no real node uses "" as its id).
  node_messages = @nodes.values.map do |node|
    neighbor_ids = node.neighbors.flat_map do |level_neighbors|
      Array.new(@m) do |slot|
        neighbor = level_neighbors[slot]
        neighbor ? neighbor.id : ""
      end
    end

    HNSWGraphNode.new(
      id: node.id,
      vector: node.vector,
      level: node.level,
      neighbors: neighbor_ids
    )
  end

  graph = HNSWGraph.new(
    # An empty store has no entry point; "" is the proto3 string default.
    entrypoint_id: @entrypoint ? @entrypoint.id : "",
    max_level: @max_level,
    similarity_metric: @similarity_metric,
    dimensions: @dimensions,
    m: @m,
    ef: @ef,
    nodes: node_messages
  )

  graph.to_proto
end
|
112
|
+
|
113
|
+
# Deserializes a binary string into a vector store.
|
114
|
+
def self.deserialize(serialized_data)
  graph = HNSWGraph.decode(serialized_data)
  store = new(graph.similarity_metric, graph.dimensions, graph.m, graph.ef)

  # Pass 1: materialize every node so neighbor ids can be resolved afterwards.
  graph.nodes.each do |record|
    store.nodes[record.id] = HNSWNode.new(record.id, record.vector, record.level, graph.m)
  end

  # Pass 2: the flat neighbor id list is cut into groups of m, one group per
  # level; ids that do not name a known node are skipped.
  graph.nodes.each do |record|
    owner = store.nodes[record.id]

    record.neighbors.each_slice(graph.m).each_with_index do |ids, level|
      ids.each_with_index do |neighbor_id, slot|
        next unless store.nodes.key?(neighbor_id)

        owner.neighbors[level][slot] = store.nodes[neighbor_id]
      end
    end
  end

  # Restore the state that has no public writer.
  store.instance_variable_set(:@entrypoint, store.nodes[graph.entrypoint_id])
  store.instance_variable_set(:@max_level, graph.max_level)

  store
end
|
139
|
+
|
140
|
+
# Finds the nearest neighbors of a vector.
|
141
|
+
def find_neighbors(node, query, level)
  # A k-NN search on one level, starting at +node+, with k fixed to @m.
  neighbor_limit = @m
  search_knn(node, query, neighbor_limit, level)
end
|
144
|
+
|
145
|
+
# Updates the neighbors of a node.
|
146
|
+
def update_neighbors(node, neighbors, query, level)
  # Forward links: the new node keeps at most @m of the supplied neighbors.
  node.neighbors[level] = neighbors.slice(0, @m)

  # Back links: offer the new node to each neighbor in turn.
  neighbors.each do |peer|
    new_distance = distance(peer.vector, query)
    slots = peer.neighbors[level]

    # First slot that is empty or holds a node farther from the peer than
    # the new node is; nil means the new node does not qualify.
    position = slots.index do |existing|
      existing.nil? || new_distance < distance(peer.vector, existing.vector)
    end
    next if position.nil?

    slots.insert(position, node)
    slots.pop if slots.size > @m # keep the slot count bounded by @m
  end
end
|
158
|
+
|
159
|
+
# Updates maximum level of the graph.
|
160
|
+
def update_max_level(level)
  # Only ever raises the recorded maximum; lower or equal levels are ignored.
  return unless level > @max_level

  @max_level = level
end
|
163
|
+
|
164
|
+
# Finds the k nearest neighbors of a vector.
|
165
|
+
def nearest_neighbors(query, k)
  # An empty store has nothing to return, and asking for zero (or fewer)
  # results previously crashed inside the search on a nil comparison.
  return [] if @entrypoint.nil? || k <= 0

  # Greedy descent: the best node found on each level seeds the next lower one.
  entry_point = @max_level.downto(0).reduce(@entrypoint) do |current, level|
    search_level(query, current, level)
  end

  # The actual k-NN search runs on level 0.
  search_knn(entry_point, query, k, 0)
end
|
173
|
+
|
174
|
+
# Greedy walk on a single level: keep moving to the closest neighbor for as
# long as that neighbor is strictly closer to +query+ than the current node.
#
# @param query [Array<Numeric>] the query vector
# @param entry_point [HNSWNode] node the walk starts from
# @param level [Integer] level whose neighbor lists are followed
# @return [HNSWNode] the node where the walk stopped
def search_level(query, entry_point, level)
  current = entry_point
  best_distance = distance(query, current.vector)

  loop do
    candidate, candidate_distance = find_closest_neighbor(query, current.neighbors[level])
    # Stop when there is no neighbor at all or none improves on the current node.
    break unless candidate && candidate_distance < best_distance

    current = candidate
    best_distance = candidate_distance
  end

  current
end
|
191
|
+
|
192
|
+
# Scans a neighbor list (which may contain empty nil slots) for the node
# closest to +query+.
#
# @param query [Array<Numeric>] the query vector
# @param neighbors [Array<HNSWNode, nil>] neighbor slots to scan
# @return [Array] a [node, distance] pair; [nil, Float::INFINITY] when no
#   neighbor has a distance below infinity
def find_closest_neighbor(query, neighbors)
  neighbors.reduce([nil, Float::INFINITY]) do |(best_node, best_distance), candidate|
    next [best_node, best_distance] unless candidate # empty slot

    candidate_distance = distance(query, candidate.vector)
    # Strict comparison: on a tie the earlier neighbor wins.
    candidate_distance < best_distance ? [candidate, candidate_distance] : [best_node, best_distance]
  end
end
|
207
|
+
|
208
|
+
# Finds the k nearest neighbors of a vector.
|
209
|
+
def search_knn(entry_point, query, k, level)
  # Best-first search over one level, starting at +entry_point+ and
  # returning up to k nodes (in no guaranteed order).
  return [] if k <= 0 # nothing requested; also avoids comparing against an empty result

  visited = Set.new
  candidates = Set.new([entry_point])
  result = []

  until candidates.empty?
    closest = find_closest_candidate(candidates, query)
    candidates.delete(closest)
    visited.add(closest.id)

    # Decide whether to stop BEFORE recording the candidate. Checking after
    # insertion compared the candidate with itself whenever it had just
    # become the worst kept result, which ended the search the moment the
    # result filled up (with k = 1 the entry point was never expanded).
    break if termination_condition_met?(result, closest, query, k)

    result = update_result(result, closest, query, k)
    add_neighbors_to_candidates(closest, level, visited, candidates)
  end

  result
end
|
228
|
+
|
229
|
+
# Picks the candidate whose vector is closest to +query+.
#
# @param candidates [Enumerable<HNSWNode>] nodes still to be examined
# @param query [Array<Numeric>] the query vector
# @return [HNSWNode, nil] the closest candidate, nil when there are none
def find_closest_candidate(candidates, query)
  candidates.min_by do |candidate|
    distance(query, candidate.vector)
  end
end
|
232
|
+
|
233
|
+
# Offers +candidate+ to the bounded result list: it is appended while there
# is room, and otherwise replaces the farthest kept node if it is strictly
# closer to +query+. Mutates and returns +result+.
#
# @param result [Array<HNSWNode>] nodes kept so far
# @param candidate [HNSWNode] node being offered
# @param query [Array<Numeric>] the query vector
# @param k [Integer] maximum number of nodes to keep
# @return [Array<HNSWNode>] the same +result+ array
def update_result(result, candidate, query, k)
  return result.push(candidate) if result.size < k

  worst = result.max_by { |kept| distance(query, kept.vector) }
  if distance(query, candidate.vector) < distance(query, worst.vector)
    result.delete(worst)
    result.push(candidate)
  end

  result
end
|
248
|
+
|
249
|
+
# Tells the search whether it may stop: true once +result+ holds at least k
# nodes and +closest+ is no nearer to +query+ than the farthest kept node.
#
# @param result [Array<HNSWNode>] nodes kept so far
# @param closest [HNSWNode] the candidate just taken from the frontier
# @param query [Array<Numeric>] the query vector
# @param k [Integer] number of results wanted
# @return [Boolean]
def termination_condition_met?(result, closest, query, k)
  return false if result.size < k

  worst = result.max_by { |kept| distance(query, kept.vector) }
  worst_distance = distance(query, worst.vector)

  distance(query, closest.vector) >= worst_distance
end
|
257
|
+
|
258
|
+
# Adds the neighbors of +closest+ on +level+ to the candidate set, skipping
# empty slots and nodes whose id has already been visited.
#
# @param closest [HNSWNode] node whose neighbors are expanded
# @param level [Integer] level whose neighbor list is read
# @param visited [Set] ids of nodes already examined
# @param candidates [Set<HNSWNode>] set the neighbors are added to (mutated)
def add_neighbors_to_candidates(closest, level, visited, candidates)
  closest.neighbors[level].each do |neighbor|
    candidates.add(neighbor) if neighbor && !visited.include?(neighbor.id)
  end
end
|
265
|
+
|
266
|
+
# Calculates the distance between two vectors.
|
267
|
+
def distance(from, to)
  # Dispatch on the configured metric; it may be stored as a String or a
  # Symbol, so normalize it first.
  metric = @similarity_metric.to_sym

  return euclidean_distance(from, to) if metric == :euclidean
  return cosine_distance(from, to) if metric == :cosine

  raise UnsupportedSimilarityMetricError, "Similarity metric #{@similarity_metric} is not supported"
end
|
277
|
+
|
278
|
+
# Calculates the euclidean distance between two vectors.
|
279
|
+
def euclidean_distance(from, to)
  # Sum of squared component differences, walked over the indices of +from+,
  # then the square root. A plain left-to-right fold keeps the arithmetic
  # order of a running `+=` accumulator.
  squared_sum = from.each_with_index.inject(0) do |total, (value, index)|
    total + (value - to[index])**2
  end

  Math.sqrt(squared_sum)
end
|
287
|
+
|
288
|
+
# Calculates the cosine distance between two vectors.
|
289
|
+
def cosine_distance(from, to)
  # Returns 1 - cosine similarity, walking the indices of +from+.
  dot_product = 0
  norm_from = 0
  norm_to = 0

  from.each_with_index do |value, index|
    dot_product += value * to[index]
    norm_from += value**2
    norm_to += to[index]**2
  end

  magnitude = Math.sqrt(norm_from) * Math.sqrt(norm_to)
  # A zero-length vector has no direction; dividing by zero produced NaN,
  # which makes every comparison false and breaks min_by/max_by in the
  # search code. Report "no similarity" (distance 1.0) instead.
  return 1.0 if magnitude.zero?

  1 - (dot_product / magnitude)
end
|
302
|
+
|
303
|
+
# Returns a random level for a node.
|
304
|
+
def get_random_level
  # Coin-flip promotion: keep climbing while a fresh random draw is below
  # the first probability factor and the current maximum is not yet reached.
  #
  # NOTE(review): the climb is capped by @max_level, which starts at 0, so on
  # a freshly built store this always returns 0 and no upper level is ever
  # created. Confirm whether that cap is intended before relying on levels.
  level = 0
  level += 1 while rand < PROBABILITY_FACTORS.first && level < @max_level
  level
end
|
311
|
+
|
312
|
+
# HNSW vector store node.
|
313
|
+
# A single graph node: an id, its vector, its level, and one fixed-size list
# of neighbor slots for every level from 0 up to the node's own level.
class HNSWNode
  attr_reader :id, :vector
  attr_accessor :level, :neighbors

  # @param id [String] the node ID (ULID)
  # @param vector [Array] the node vector
  # @param level [Integer] the node level
  # @param m [Integer] the number of neighbor slots per level
  def initialize(id, vector, level, m)
    @id, @vector, @level = id, vector, level
    # level + 1 independent slot lists, each starting as m empty (nil) slots.
    @neighbors = Array.new(level + 1) { Array.new(m, nil) }
  end
end
|
331
|
+
|
332
|
+
# BoundedPriorityQueue is a data structure that keeps a priority queue
|
333
|
+
# of a bounded size. It maintains the top-k elements with the smallest
|
334
|
+
# priorities. It uses an underlying PriorityQueue to store elements.
|
335
|
+
# Priority queue limited to max_size items that retains the items with the
# smallest priorities. Items are indexable and item[0] is the priority.
class BoundedPriorityQueue
  # @param max_size [Integer] maximum number of items kept
  def initialize(max_size)
    @max_size = max_size
    @queue = PriorityQueue.new
  end

  # @return [Integer] number of items currently held
  def size
    @queue.size
  end

  # Inserts an item. When the queue is full the item is only accepted if its
  # priority is smaller than the largest priority currently held, and that
  # largest item is evicted to make room.
  #
  # The previous implementation compared against (and evicted) the smallest
  # item, so the queue ended up discarding the best entries.
  #
  # @param item [Array] item whose first element is its priority
  def push(item)
    return @queue.push(item) if size < @max_size

    entries = @queue.to_a
    worst_index = entries.each_index.max_by { |index| entries[index][0] }
    # Nothing to evict (max_size of 0) or the new item is no improvement.
    return if worst_index.nil? || item[0] >= entries[worst_index][0]

    # The underlying min-queue cannot remove its largest element directly,
    # so rebuild it from the survivors; cost is linear in max_size.
    entries.delete_at(worst_index)
    @queue = PriorityQueue.new
    entries.each { |entry| @queue.push(entry) }
    @queue.push(item)
  end

  # @return [Array, nil] the item with the smallest priority, not removed
  def peek
    @queue.peek
  end

  # @return [Array] the held items, in the underlying queue's internal order
  def to_a
    @queue.to_a
  end
end
|
368
|
+
|
369
|
+
# PriorityQueue is a data structure that keeps elements ordered by priority.
|
370
|
+
# It supports inserting elements, removing the element with the smallest
|
371
|
+
# priority, and peeking at the element with the smallest priority. It uses
|
372
|
+
# a binary heap as the underlying data structure.
|
373
|
+
# Min-priority queue backed by an array-based binary heap. Items are
# indexable and item[0] is the priority; the smallest priority is at the root.
class PriorityQueue
  def initialize
    @elements = []
  end

  # @return [Integer] number of stored items
  def size
    @elements.size
  end

  # @return [Boolean] true when nothing is stored
  def empty?
    @elements.empty?
  end

  # Appends the item and restores the heap order.
  def push(item)
    @elements << item
    shift_up(@elements.size - 1)
  end

  # Removes and returns the item with the smallest priority, or nil when the
  # queue is empty.
  def pop
    return if empty?

    # Move the root to the end so it can be removed cheaply, then repair.
    swap(0, @elements.size - 1)
    smallest = @elements.pop
    shift_down(0)
    smallest
  end

  # @return [Array, nil] the item with the smallest priority, not removed
  def peek
    @elements.first
  end

  # @return [Array] a copy of the items in heap order (not sorted)
  def to_a
    @elements.dup
  end

  private

  def swap(i, j)
    @elements[i], @elements[j] = @elements[j], @elements[i]
  end

  # Moves the item at index i towards the root while its parent is larger.
  def shift_up(i)
    while i > 0
      parent = (i - 1) / 2
      break if @elements[parent][0] <= @elements[i][0]

      swap(i, parent)
      i = parent
    end
  end

  # Moves the item at index i towards the leaves while a child is smaller.
  def shift_down(i)
    loop do
      left_child = 2 * i + 1
      right_child = 2 * i + 2

      smallest = i
      smallest = left_child if left_child < size && @elements[left_child][0] < @elements[smallest][0]
      smallest = right_child if right_child < size && @elements[right_child][0] < @elements[smallest][0]
      break if smallest == i

      swap(i, smallest)
      i = smallest
    end
  end
end
|
440
|
+
end
|
441
|
+
end
|
442
|
+
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: hnsw.proto
|
3
|
+
|
4
|
+
require "google/protobuf"
|
5
|
+
|
6
|
+
# Registers the descriptors for the two messages of hnsw.proto in the
# generated descriptor pool. The file header marks this script as compiler
# output ("DO NOT EDIT"), so schema changes belong in hnsw.proto.
Google::Protobuf::DescriptorPool.generated_pool.build do
|
7
|
+
add_file("hnsw.proto", syntax: :proto3) do
|
8
|
+
add_message "HNSWGraphNode" do
|
9
|
+
optional :id, :string, 1
|
10
|
+
repeated :vector, :float, 2
|
11
|
+
optional :level, :int32, 3
|
12
|
+
repeated :neighbors, :string, 4
|
13
|
+
end
|
14
|
+
add_message "HNSWGraph" do
|
15
|
+
optional :entrypoint_id, :string, 1
|
16
|
+
optional :max_level, :int32, 2
|
17
|
+
optional :similarity_metric, :string, 3
|
18
|
+
optional :dimensions, :int32, 4
|
19
|
+
optional :m, :int32, 5
|
20
|
+
optional :ef, :int32, 6
|
21
|
+
repeated :nodes, :message, 7, "HNSWGraphNode"
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+

|
26
|
+
# Top-level message classes looked up from the pool by message name.
HNSWGraphNode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraphNode").msgclass
|
27
|
+
HNSWGraph = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraph").msgclass
|