roseflow 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. checksums.yaml +4 -4
  2. data/.standard.yml +3 -1
  3. data/CHANGELOG.md +2 -2
  4. data/Gemfile +2 -0
  5. data/examples/github-repo-chat/lib/actions/clone_and_load_repository.rb +52 -0
  6. data/examples/github-repo-chat/lib/actions/create_prompt.rb +15 -0
  7. data/examples/github-repo-chat/lib/actions/embed_repository.rb +35 -0
  8. data/examples/github-repo-chat/lib/actions/initialize_vector_store.rb +40 -0
  9. data/examples/github-repo-chat/lib/actions/interact_with_model.rb +29 -0
  10. data/examples/github-repo-chat/lib/actions/load_documents_to_database.rb +0 -0
  11. data/examples/github-repo-chat/lib/actions/split_files_to_documents.rb +55 -0
  12. data/examples/github-repo-chat/lib/document_database.rb +0 -0
  13. data/examples/github-repo-chat/lib/github_chat_prompt.rb +24 -0
  14. data/examples/github-repo-chat/lib/github_repository_chat.rb +12 -0
  15. data/examples/github-repo-chat/lib/interactions/ask_llm.rb +31 -0
  16. data/examples/github-repo-chat/lib/interactions/github_repository_chat.rb +36 -0
  17. data/examples/github-repo-chat/lib/interactions/load_files_to_document_database.rb +18 -0
  18. data/examples/github-repo-chat/lib/interactions/load_repository.rb +20 -0
  19. data/examples/github-repo-chat/lib/interactions/prepare_vector_store.rb +21 -0
  20. data/examples/github-repo-chat/lib/repository.rb +9 -0
  21. data/examples/github-repo-chat/lib/repository_file.rb +31 -0
  22. data/examples/github-repo-chat/spec/actions/clone_and_load_repository_spec.rb +28 -0
  23. data/examples/github-repo-chat/spec/actions/embed_repository_spec.rb +24 -0
  24. data/examples/github-repo-chat/spec/actions/initialize_vector_store_spec.rb +20 -0
  25. data/examples/github-repo-chat/spec/actions/load_files_to_document_database_spec.rb +23 -0
  26. data/examples/github-repo-chat/spec/fixtures/ulid-ruby.zip +0 -0
  27. data/examples/github-repo-chat/spec/github_repository_chat_spec.rb +16 -0
  28. data/examples/github-repo-chat/spec/interactions/prepare_vector_store_spec.rb +4 -0
  29. data/examples/github-repo-chat/spec/spec_helper.rb +12 -0
  30. data/lib/roseflow/action.rb +13 -0
  31. data/lib/roseflow/actions/ai/resolve_model.rb +27 -0
  32. data/lib/roseflow/actions/ai/resolve_provider.rb +31 -0
  33. data/lib/roseflow/ai/model.rb +19 -0
  34. data/lib/roseflow/ai/provider.rb +30 -0
  35. data/lib/roseflow/chat/dialogue.rb +80 -0
  36. data/lib/roseflow/chat/exchange.rb +12 -0
  37. data/lib/roseflow/chat/message.rb +39 -0
  38. data/lib/roseflow/chat/personality.rb +10 -0
  39. data/lib/roseflow/embeddings/embedding.rb +26 -0
  40. data/lib/roseflow/finite_machine.rb +298 -0
  41. data/lib/roseflow/interaction/with_http_api.rb +10 -0
  42. data/lib/roseflow/interaction.rb +14 -0
  43. data/lib/roseflow/interaction_context.rb +10 -0
  44. data/lib/roseflow/interactions/ai/initialize_llm.rb +26 -0
  45. data/lib/roseflow/primitives/vector.rb +19 -0
  46. data/lib/roseflow/prompt.rb +17 -0
  47. data/lib/roseflow/text/completion.rb +16 -0
  48. data/lib/roseflow/text/recursive_character_splitter.rb +43 -0
  49. data/lib/roseflow/text/sentence_splitter.rb +42 -0
  50. data/lib/roseflow/text/splitter.rb +18 -0
  51. data/lib/roseflow/text/tokenized_text.rb +20 -0
  52. data/lib/roseflow/text/word_splitter.rb +14 -0
  53. data/lib/roseflow/tokenizer.rb +13 -0
  54. data/lib/roseflow/types.rb +9 -0
  55. data/lib/roseflow/vector_stores/base.rb +39 -0
  56. data/lib/roseflow/vector_stores/hnsw.proto +18 -0
  57. data/lib/roseflow/vector_stores/hnsw_memory_store.rb +442 -0
  58. data/lib/roseflow/vector_stores/hnsw_pb.rb +27 -0
  59. data/lib/roseflow/vector_stores/type/vector.rb +38 -0
  60. data/lib/roseflow/vector_stores/vector.rb +19 -0
  61. data/lib/roseflow/version.rb +12 -1
  62. data/lib/roseflow.rb +10 -1
  63. data/roseflow.gemspec +53 -0
  64. metadata +274 -7
@@ -0,0 +1,43 @@
# frozen_string_literal: true

require "roseflow/text/splitter"

module Roseflow
  module Text
    # Splits text into chunks of at most +chunk_size+ characters by breaking it
    # on the first separator from +separators+ that occurs in the text.
    #
    # NOTE: despite the name, segments are not split again recursively; a single
    # segment longer than +chunk_size+ becomes a chunk on its own.
    class RecursiveCharacterSplitter < Splitter
      # Separators tried in order. "" always matches and splits into characters.
      SEPARATORS = ["\n\n", "\n", " ", ""].freeze

      # @param separators [Array<String>, nil] separators to try in order; nil means SEPARATORS
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+, forwarded to Splitter
      def initialize(separators = nil, **kwargs)
        super(**kwargs)
        @separators = separators || SEPARATORS
      end

      attr_reader :chunk_size, :chunk_overlap

      # Splits +text+ into chunks.
      #
      # Segments are appended to the current chunk while its joined length
      # (segments plus the separators between them) stays within +chunk_size+.
      # When a segment does not fit, a new chunk is started that repeats the
      # last +chunk_overlap+ segments of the previous chunk (the overlap counts
      # segments, not characters), so such a chunk can exceed +chunk_size+.
      #
      # @param text [String] the text to split
      # @return [Array<String>] chunks re-joined with the separator that split
      #   them; empty when +text+ yields no segments
      def split(text)
        separator = find_separator(text)
        chunks = []
        current = []
        current_size = 0

        text.split(separator).each do |segment|
          # Only start a new chunk when the current one already has content;
          # otherwise an oversized first segment would emit an empty chunk.
          if current.any? && current_size + separator.size + segment.size > chunk_size
            chunks << current
            current = current.last(chunk_overlap)
            current_size = joined_size(current, separator)
          end

          current_size += current.empty? ? segment.size : separator.size + segment.size
          current << segment
        end
        chunks << current unless current.empty?

        # Join with the real separator: joining with " " would, for example,
        # turn a character split ("" separator) of "abc" into "a b c".
        chunks.map { |chunk| chunk.join(separator) }
      end

      private

      # Returns the first configured separator present in +text+. Falls back to
      # the last configured one, or to " " (whitespace split) when none are configured.
      def find_separator(text)
        @separators.find { |separator| text.include?(separator) } || @separators.last || " "
      end

      # Length of +segments+ once joined with +separator+.
      def joined_size(segments, separator)
        return 0 if segments.empty?

        segments.sum(&:size) + separator.size * (segments.size - 1)
      end
    end # RecursiveCharacterSplitter
  end # Text
end # Roseflow
@@ -0,0 +1,42 @@
# frozen_string_literal: true

require "pragmatic_segmenter"
require "roseflow/text/splitter"

module Roseflow
  module Text
    # Splits text into chunks made of whole sentences, detected with PragmaticSegmenter.
    class SentenceSplitter < Splitter
      # Inserted between sentences when a chunk is joined back into a string.
      SENTENCE_JOINER = " "
      private_constant :SENTENCE_JOINER

      # @param language [String] language code passed to PragmaticSegmenter
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+, forwarded to Splitter
      def initialize(language: "en", **kwargs)
        super(**kwargs)
        @language = language
      end

      attr_reader :chunk_size, :chunk_overlap

      # Builds a sentence segmenter for +text+.
      #
      # Intentionally not memoized: the segmenter is bound to the text it was
      # created with, so caching it would make every later #split return the
      # sentences of the first text ever split.
      def segmenter(text)
        PragmaticSegmenter::Segmenter.new(text: text, language: @language)
      end

      # Splits +text+ into chunks of consecutive sentences.
      #
      # Sentences are appended to the current chunk while its joined length
      # (sentences plus one space between each) stays within +chunk_size+.
      # When a sentence does not fit, a new chunk is started that repeats the
      # last +chunk_overlap+ sentences of the previous chunk (the overlap counts
      # sentences, not characters), so such a chunk can exceed +chunk_size+.
      #
      # @param text [String] the text to split
      # @return [Array<String>] the chunks; empty when no sentences are found
      def split(text)
        chunks = []
        current = []
        current_size = 0

        segmenter(text).segment.each do |sentence|
          # Only start a new chunk when the current one already has content;
          # otherwise an oversized first sentence would emit an empty chunk.
          if current.any? && current_size + SENTENCE_JOINER.size + sentence.size > chunk_size
            chunks << current
            current = current.last(chunk_overlap)
            current_size = joined_size(current)
          end

          current_size += current.empty? ? sentence.size : SENTENCE_JOINER.size + sentence.size
          current << sentence
        end
        chunks << current unless current.empty?

        chunks.map { |chunk| chunk.join(SENTENCE_JOINER) }
      end

      private

      # Length of +sentences+ once joined with SENTENCE_JOINER.
      def joined_size(sentences)
        return 0 if sentences.empty?

        sentences.sum(&:size) + SENTENCE_JOINER.size * (sentences.size - 1)
      end
    end
  end
end
@@ -0,0 +1,18 @@
# frozen_string_literal: true

module Roseflow
  module Text
    # Abstract base class for text splitters.
    #
    # Stores and validates the chunking limits; subclasses implement #split.
    class Splitter
      attr_reader :chunk_size, :chunk_overlap

      # @param chunk_size [Integer] maximum size of a chunk (unit chosen by the subclass)
      # @param chunk_overlap [Integer] amount of content repeated between consecutive chunks
      # @raise [ArgumentError] if +chunk_overlap+ is negative or exceeds +chunk_size+
      def initialize(chunk_size:, chunk_overlap:)
        # Reject a negative overlap up front instead of failing later inside Array#last.
        raise ArgumentError, "chunk overlap cannot be negative" if chunk_overlap.negative?
        raise ArgumentError, "chunk overlap cannot exceed chunk size" if chunk_overlap > chunk_size

        @chunk_size = chunk_size
        @chunk_overlap = chunk_overlap
      end

      # Splits +text+ into chunks.
      #
      # @abstract
      # @raise [NotImplementedError] always, until overridden by a subclass
      def split(text)
        raise NotImplementedError, "this class must be extended and the #split method implemented"
      end
    end
  end
end
@@ -0,0 +1,20 @@
# frozen_string_literal: true

require "dry-struct"

module Roseflow
  module Text
    # A piece of text together with the integer token IDs it was encoded to.
    #
    # NOTE(review): references the top-level Types module (roseflow/types),
    # which this file does not require itself — it must be loaded first.
    class TokenizedText < Dry::Struct
      attribute :text, Types::String
      attribute :tokens, Types::Array.of(Types::Integer)

      # @return [Integer] how many tokens the text was encoded to
      def token_count
        tokens.size
      end

      # @return [String] the underlying text
      def to_s
        text
      end
    end
  end
end
@@ -0,0 +1,14 @@
# frozen_string_literal: true

require "roseflow/text/recursive_character_splitter"

module Roseflow
  module Text
    # A RecursiveCharacterSplitter that, by default, splits only on spaces,
    # producing chunks of whole words.
    class WordSplitter < RecursiveCharacterSplitter
      # Separators used when none (or nil) are given.
      WORD_SEPARATORS = [" "].freeze

      # @param separators [Array<String>, nil] separators to try in order;
      #   nil falls back to WORD_SEPARATORS instead of being stored and
      #   crashing later in the separator lookup
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+, forwarded to Splitter
      def initialize(separators = WORD_SEPARATORS, **kwargs)
        super(separators || WORD_SEPARATORS, **kwargs)
      end
    end # WordSplitter
  end # Text
end # Roseflow
@@ -0,0 +1,13 @@
# frozen_string_literal: true

module Roseflow
  # Abstract tokenizer. Concrete tokenizers subclass it and override both
  # #encode and #decode; the base versions only raise.
  class Tokenizer
    # Subclass hook for encoding +input+.
    #
    # @abstract
    # @raise [NotImplementedError] always, until overridden
    def encode(input)
      raise NotImplementedError, "this class must be extended and the ##{__method__} method implemented"
    end

    # Subclass hook for decoding +input+.
    #
    # @abstract
    # @raise [NotImplementedError] always, until overridden
    def decode(input)
      raise NotImplementedError, "this class must be extended and the ##{__method__} method implemented"
    end
  end
end
@@ -0,0 +1,9 @@
# frozen_string_literal: true

require "dry-struct"

# dry-types namespace used by Roseflow's Dry::Struct models
# (for example Roseflow::Text::TokenizedText references Types::String).
#
# NOTE(review): defined at the top level rather than under Roseflow, so it may
# clash with an application's own Types module — confirm this is intended.
module Types
  include Dry.Types()

  # Sum type accepting either a Float or an Integer.
  Number = Types::Float | Types::Integer
end
@@ -0,0 +1,39 @@
# frozen_string_literal: true

module Roseflow
  module VectorStores
    # Abstract interface shared by all vector stores.
    #
    # Concrete stores subclass Base and override the operations they support;
    # any operation left alone raises NotImplementedError naming itself.
    class Base
      # Hook presumably answering whether embeddings exist for +klass+.
      #
      # @return [Boolean] always false in the base store
      def has_embeddings?(klass)
        false
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def list_vectors(namespace = nil)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def build_vector(id, attrs)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def create_vector(vector)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def delete_vector(name)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def update_vector(vector)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def query(query)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      # @raise [NotImplementedError] until overridden
      def find(query)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end
    end
  end
end
@@ -0,0 +1,18 @@
syntax = "proto3";

// One node of an HNSW graph as written by HNSWMemoryStore#serialize.
message HNSWGraphNode {
  // Node identifier (the key used in the store's node hash).
  string id = 1;
  // Vector components. proto `float` is 32-bit, so values written from Ruby
  // Floats lose precision on a round trip.
  repeated float vector = 2;
  // Highest graph layer the node belongs to; it has neighbor lists for 0..level.
  int32 level = 3;
  // Neighbor IDs of every layer, concatenated in layer order;
  // HNSWMemoryStore.deserialize reads them back in slices of `m` per layer.
  repeated string neighbors = 4;
}

// A whole HNSW graph plus the parameters needed to rebuild the store.
message HNSWGraph {
  // ID of the node searches start from.
  string entrypoint_id = 1;
  // Highest layer present in the graph.
  int32 max_level = 2;
  // Name of the distance metric (e.g. "euclidean" or "cosine").
  string similarity_metric = 3;
  int32 dimensions = 4;
  // Neighbors kept per node and layer.
  int32 m = 5;
  // Search breadth parameter.
  int32 ef = 6;
  repeated HNSWGraphNode nodes = 7;
}
@@ -0,0 +1,442 @@
# frozen_string_literal: true

require "set"
require "ulid"
require "roseflow/vector_stores/hnsw_pb"

module Roseflow
  module VectorStores
    UnsupportedSimilarityMetricError = Class.new(StandardError)

    # HNSWMemoryStore is an in-memory vector store that implements
    # the HNSW (Hierarchical Navigable Small World) algorithm.
    #
    # Every node lives on layer 0 and, with geometrically decreasing
    # probability, on higher layers as well. The entry point is always a node
    # on the top layer. Searches descend greedily layer by layer and finish
    # with a best-first search on layer 0.
    class HNSWMemoryStore < Base
      # Level-generation factors; #get_random_level uses the first one.
      PROBABILITY_FACTORS = [
        0.5,
        1 / Math::E,
      ].freeze

      # Similarity metrics understood by #distance.
      SIMILARITY_METRICS = %i[euclidean cosine].freeze

      # Placeholder used in serialized graphs to pad every layer's neighbor
      # list to exactly +m+ IDs, so layers can be split apart again on load.
      # A node whose ID is this empty string cannot be linked after a round trip.
      EMPTY_NEIGHBOR_ID = ""

      # Initializes a new HNSWMemoryStore with the specified
      # similarity metric, dimensions, m and ef.
      #
      # @param similarity_metric [Symbol, String] :euclidean or :cosine
      # @param dimensions [Integer] the number of dimensions of the vectors (stored, not enforced)
      # @param m [Integer] the number of neighbors kept per node and layer
      # @param ef [Integer] the minimum search breadth used by #nearest_neighbors
      # @raise [UnsupportedSimilarityMetricError] if the similarity metric is not supported
      # @return [HNSWMemoryStore] the new HNSWMemoryStore
      def initialize(similarity_metric, dimensions, m, ef)
        unless SIMILARITY_METRICS.include?(similarity_metric.to_s.to_sym)
          raise UnsupportedSimilarityMetricError, "Similarity metric #{similarity_metric} is not supported"
        end

        @similarity_metric = similarity_metric
        @dimensions = dimensions
        @m = m
        @ef = ef
        @max_level = 0
        @entrypoint = nil
        @nodes = {}
      end

      attr_accessor :nodes

      # @return [Integer] the number of nodes in the store
      def size
        @nodes.size
      end

      # Adds a new node to the vector store. Re-adding an existing ID replaces
      # the old node instead of leaving an unreachable duplicate in the graph.
      #
      # @param node_id [String] the ID of the node
      # @param vector [Array<Numeric>] the vector of the node
      # @return [HNSWNode] the new node
      def add_node(node_id, vector)
        delete_node(node_id) if @nodes.key?(node_id)

        level = get_random_level
        node = HNSWNode.new(node_id, vector, level, @m)

        if @entrypoint.nil?
          @entrypoint = node
          @max_level = level
          return @nodes[node_id] = node
        end

        # Greedy descent through the layers above the new node's top layer.
        current_node = @entrypoint
        @max_level.downto(level + 1) do |i|
          current_node = search_level(vector, current_node, i)
        end

        # Link the node on every layer it belongs to that already exists.
        [level, @max_level].min.downto(0) do |i|
          neighbors = find_neighbors(current_node, vector, i)
          update_neighbors(node, neighbors, vector, i)
          # The closest node found on this layer is the entry for the next one.
          current_node = neighbors.first || current_node
        end

        # A node taller than the graph becomes the new entry point, so the
        # entry point always has neighbor lists for every level up to @max_level.
        if level > @max_level
          update_max_level(level)
          @entrypoint = node
        end

        @nodes[node_id] = node
      end

      alias_method :create_vector, :add_node

      # Deletes a node from the vector store and removes every link to it.
      # If it was the entry point, the remaining node with the highest level
      # takes over. Links are not repaired, so heavy deletion can reduce recall.
      # Cost is proportional to the total number of stored links.
      #
      # @param node_id [String] the ID of the node
      # @return [HNSWNode, nil] the deleted node, or nil if it was not found
      def delete_node(node_id)
        node = @nodes.delete(node_id)
        return unless node

        @nodes.each_value do |other|
          other.neighbors.each { |links| links.delete(node) }
        end

        if @entrypoint.equal?(node)
          @entrypoint = @nodes.values.max_by(&:level)
          @max_level = @entrypoint ? @entrypoint.level : 0
        end

        node
      end

      alias_method :delete_vector, :delete_node

      # Finds a node in the vector store.
      #
      # @param node_id [String] the ID of the node
      # @return [HNSWNode] the found node
      # @return [nil] if the node was not found
      def find(node_id)
        @nodes[node_id]
      end

      # Serializes the vector store to a binary protobuf string.
      #
      # Each layer's neighbor list is written as exactly +m+ IDs (padded with
      # EMPTY_NEIGHBOR_ID) so .deserialize can recover which layer a link
      # belongs to. An empty store serializes with an empty entry point ID.
      #
      # @return [String] the encoded HNSWGraph message
      def serialize
        graph = HNSWGraph.new(
          entrypoint_id: @entrypoint ? @entrypoint.id : EMPTY_NEIGHBOR_ID,
          max_level: @max_level,
          similarity_metric: @similarity_metric.to_s,
          dimensions: @dimensions,
          m: @m,
          ef: @ef,
          nodes: @nodes.values.map do |node|
            HNSWGraphNode.new(
              id: node.id,
              vector: node.vector,
              level: node.level,
              neighbors: serialized_neighbor_ids(node),
            )
          end,
        )

        graph.to_proto
      end

      # Deserializes a binary string produced by #serialize into a vector store.
      #
      # @param serialized_data [String] an encoded HNSWGraph message
      # @return [HNSWMemoryStore] the rebuilt store
      def self.deserialize(serialized_data)
        graph = HNSWGraph.decode(serialized_data)

        hnsw = new(graph.similarity_metric, graph.dimensions, graph.m, graph.ef)

        # First pass: create every node so neighbor IDs can be resolved below.
        graph.nodes.each do |node|
          hnsw.nodes[node.id] = HNSWNode.new(node.id, node.vector.to_a, node.level, graph.m)
        end

        # Second pass: neighbor IDs are stored as consecutive runs of +m+ per layer.
        graph.nodes.each do |node|
          layers = hnsw.nodes[node.id].neighbors
          node.neighbors.each_slice(graph.m).with_index do |neighbor_ids, level|
            links = layers[level]
            next unless links # ignore runs beyond the node's top layer

            neighbor_ids.each_with_index do |neighbor_id, index|
              # Padding IDs and unknown IDs are skipped, leaving the slot nil.
              links[index] = hnsw.nodes[neighbor_id] if hnsw.nodes.key?(neighbor_id)
            end
          end
        end

        hnsw.instance_variable_set(:@entrypoint, hnsw.nodes[graph.entrypoint_id])
        hnsw.instance_variable_set(:@max_level, graph.max_level)

        hnsw
      end

      # Finds up to +m+ nearest neighbors of +query+ on +level+, closest first,
      # starting from +node+.
      def find_neighbors(node, query, level)
        search_knn(node, query, @m, level)
      end

      # Links +node+ to +neighbors+ on +level+ and adds the reverse links.
      #
      # Each neighbor then keeps only its +m+ closest links on that layer,
      # sorted closest first. +query+ is the vector of +node+; the reverse-link
      # pruning reads it from node.vector directly.
      def update_neighbors(node, neighbors, query, level)
        node.neighbors[level] = neighbors.first(@m)

        neighbors.each do |neighbor|
          next if neighbor.equal?(node) || neighbor.level < level

          links = neighbor.neighbors[level].compact
          next if links.include?(node)

          links << node
          neighbor.neighbors[level] = links.min_by(@m) { |link| distance(neighbor.vector, link.vector) }
        end
      end

      # Updates maximum level of the graph.
      def update_max_level(level)
        @max_level = level if level > @max_level
      end

      # Finds the k nearest neighbors of a vector.
      #
      # The layer-0 search explores at least +ef+ candidates (or +k+ if larger)
      # before the best +k+ are returned.
      #
      # @param query [Array<Numeric>] the query vector
      # @param k [Integer] the number of neighbors to return
      # @return [Array<HNSWNode>] up to +k+ nodes, closest first; [] when the
      #   store is empty or +k+ is not positive
      def nearest_neighbors(query, k)
        return [] unless @entrypoint && k.positive?

        entry_point = @entrypoint
        @max_level.downto(1) do |level|
          entry_point = search_level(query, entry_point, level)
        end
        search_knn(entry_point, query, [k, @ef.to_i].max, 0).first(k)
      end

      # Greedily walks +level+ from +entry_point+ towards +query+ and returns
      # the closest node reached (a local minimum).
      def search_level(query, entry_point, level)
        current = entry_point
        best_distance = distance(query, current.vector)

        loop do
          closest_neighbor, closest_distance = find_closest_neighbor(query, Array(current.neighbors[level]))
          break unless closest_neighbor && closest_distance < best_distance

          best_distance = closest_distance
          current = closest_neighbor
        end

        current
      end

      # Returns [closest non-nil neighbor, its distance] for +query+, or
      # [nil, Float::INFINITY] when +neighbors+ has no nodes.
      def find_closest_neighbor(query, neighbors)
        closest_neighbor = nil
        closest_distance = Float::INFINITY

        neighbors.each do |neighbor|
          next unless neighbor

          neighbor_distance = distance(query, neighbor.vector)
          next unless neighbor_distance < closest_distance

          closest_distance = neighbor_distance
          closest_neighbor = neighbor
        end

        [closest_neighbor, closest_distance]
      end

      # Best-first search for the k nearest neighbors of +query+ on +level+.
      #
      # @return [Array<HNSWNode>] up to +k+ nodes sorted closest first; [] when
      #   +entry_point+ is nil or +k+ is not positive
      def search_knn(entry_point, query, k, level)
        return [] if entry_point.nil? || k <= 0

        visited = Set.new
        candidates = Set.new([entry_point])
        result = []

        until candidates.empty?
          closest = find_closest_candidate(candidates, query)
          candidates.delete(closest)
          visited.add(closest.id)

          result = update_result(result, closest, query, k)

          break if termination_condition_met?(result, closest, query, k)

          add_neighbors_to_candidates(closest, level, visited, candidates)
        end

        result.sort_by { |node| distance(query, node.vector) }
      end

      # Returns the candidate closest to +query+.
      def find_closest_candidate(candidates, query)
        candidates.min_by { |c| distance(query, c.vector) }
      end

      # Adds +candidate+ to +result+ while it has fewer than +k+ entries;
      # afterwards it replaces the furthest entry if the candidate is closer.
      def update_result(result, candidate, query, k)
        if result.size < k
          result.push(candidate)
        else
          furthest_result = result.max_by { |r| distance(query, r.vector) }
          closest_distance = distance(query, candidate.vector)
          furthest_result_distance = distance(query, furthest_result.vector)

          if closest_distance < furthest_result_distance
            result.delete(furthest_result)
            result.push(candidate)
          end
        end
        result
      end

      # True once +result+ is full and the closest remaining candidate is no
      # nearer than the furthest result.
      def termination_condition_met?(result, closest, query, k)
        return false if result.size < k

        furthest_result_distance = distance(query, result.max_by { |r| distance(query, r.vector) }.vector)
        closest_distance = distance(query, closest.vector)

        closest_distance >= furthest_result_distance
      end

      # Queues the unvisited neighbors of +closest+ on +level+.
      def add_neighbors_to_candidates(closest, level, visited, candidates)
        Array(closest.neighbors[level]).each do |neighbor|
          next if neighbor.nil? || visited.include?(neighbor.id)

          candidates.add(neighbor)
        end
      end

      # Calculates the distance between two vectors using the configured metric.
      #
      # @raise [UnsupportedSimilarityMetricError] if the metric is not supported
      def distance(from, to)
        case @similarity_metric.to_sym
        when :euclidean
          euclidean_distance(from, to)
        when :cosine
          cosine_distance(from, to)
        else
          raise UnsupportedSimilarityMetricError, "Similarity metric #{@similarity_metric} is not supported"
        end
      end

      # Calculates the euclidean distance between two vectors.
      def euclidean_distance(from, to)
        squared_sum = from.each_with_index.sum { |value, index| (value - to[index])**2 }
        Math.sqrt(squared_sum)
      end

      # Calculates the cosine distance (1 - cosine similarity) between two vectors.
      def cosine_distance(from, to)
        dot_product = 0
        norm_from = 0
        norm_to = 0

        from.each_with_index do |value, index|
          other = to[index]
          dot_product += value * other
          norm_from += value**2
          norm_to += other**2
        end

        magnitude = Math.sqrt(norm_from) * Math.sqrt(norm_to)
        # A zero vector has no direction. Treat it as unrelated (distance 1.0,
        # as for orthogonal vectors) rather than returning NaN, which would make
        # every distance comparison in the search false.
        return 1.0 if magnitude.zero?

        1 - (dot_product / magnitude)
      end

      # Returns a random level for a new node.
      #
      # Each extra level is kept with probability PROBABILITY_FACTORS[0], so
      # level counts are geometrically distributed. This is deliberately not
      # capped at the current @max_level: a cap starting at 0 keeps every node
      # on layer 0 forever and the hierarchy is never built.
      def get_random_level
        level = 0
        level += 1 while rand < PROBABILITY_FACTORS[0]
        level
      end

      private

      # Flattens +node+'s neighbor lists into IDs, padding or truncating every
      # layer to exactly @m entries so .deserialize can split them per layer.
      def serialized_neighbor_ids(node)
        node.neighbors.flat_map do |links|
          ids = links.compact.first(@m).map(&:id)
          ids + Array.new(@m - ids.size, EMPTY_NEIGHBOR_ID)
        end
      end

      public

      # HNSW vector store node.
      class HNSWNode
        attr_reader :id, :vector
        attr_accessor :level, :neighbors

        # Initializes a new node.
        #
        # @param id [String] the node ID (ULID)
        # @param vector [Array] the node vector
        # @param level [Integer] the node level
        # @param m [Integer] the number of neighbors
        # @return [HNSWNode] the node
        def initialize(id, vector, level, m)
          @id = id
          @vector = vector
          @level = level
          # One neighbor list per layer 0..level, initially m empty (nil) slots.
          @neighbors = Array.new(level + 1) { Array.new(m) }
        end
      end

      # BoundedPriorityQueue keeps at most +max_size+ items, retaining the
      # ones with the smallest priorities. Items expose their priority as
      # item[0]. Internally it uses a max-heap, so the largest retained
      # priority is the one evicted when a smaller item arrives.
      class BoundedPriorityQueue
        def initialize(max_size)
          @max_size = max_size
          @queue = PriorityQueue.new(max: true)
        end

        def size
          @queue.size
        end

        # Inserts an item into the BoundedPriorityQueue. If the queue is full
        # and the new item has a smaller priority than the largest retained
        # priority, that largest item is removed and the new item is added.
        def push(item)
          if size < @max_size
            @queue.push(item)
          elsif size.positive? && item[0] < @queue.peek[0]
            @queue.pop
            @queue.push(item)
          end
        end

        # Returns the item with the smallest priority without removing it from
        # the BoundedPriorityQueue, or nil when it is empty.
        def peek
          @queue.to_a.min_by { |item| item[0] }
        end

        # @return [Array] the retained items in heap order (not sorted)
        def to_a
          @queue.to_a
        end
      end

      # PriorityQueue is a binary heap of items whose priority is item[0].
      # By default it is a min-heap (the smallest priority is on top); with
      # +max: true+ it is a max-heap (the largest priority is on top).
      class PriorityQueue
        # @param max [Boolean] when true, the largest priority is on top
        def initialize(max: false)
          @elements = []
          @max = max
        end

        def size
          @elements.size
        end

        def empty?
          @elements.empty?
        end

        def push(item)
          @elements << item
          shift_up(@elements.size - 1)
        end

        # Removes and returns the top element (smallest priority by default).
        # Returns nil if the PriorityQueue is empty.
        def pop
          return if empty?

          swap(0, @elements.size - 1)
          element = @elements.pop
          shift_down(0)
          element
        end

        # Returns the top element (smallest priority by default) without
        # removing it from the PriorityQueue.
        def peek
          @elements.first
        end

        def to_a
          @elements.dup
        end

        private

        # True when element +i+ belongs above element +j+ in the heap.
        def higher_priority?(i, j)
          if @max
            @elements[i][0] > @elements[j][0]
          else
            @elements[i][0] < @elements[j][0]
          end
        end

        def swap(i, j)
          @elements[i], @elements[j] = @elements[j], @elements[i]
        end

        def shift_up(i)
          return if i <= 0

          parent = (i - 1) / 2
          return unless higher_priority?(i, parent)

          swap(i, parent)
          shift_up(parent)
        end

        def shift_down(i)
          left_child = 2 * i + 1
          right_child = left_child + 1

          top = i
          top = left_child if left_child < size && higher_priority?(left_child, top)
          top = right_child if right_child < size && higher_priority?(right_child, top)

          return if top == i

          swap(i, top)
          shift_down(top)
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: hnsw.proto
#
# Defines the top-level message classes HNSWGraphNode and HNSWGraph, which
# HNSWMemoryStore#serialize and HNSWMemoryStore.deserialize use. Change
# hnsw.proto and regenerate this file instead of editing it by hand.

require "google/protobuf"

Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("hnsw.proto", syntax: :proto3) do
    add_message "HNSWGraphNode" do
      optional :id, :string, 1
      repeated :vector, :float, 2
      optional :level, :int32, 3
      repeated :neighbors, :string, 4
    end
    add_message "HNSWGraph" do
      optional :entrypoint_id, :string, 1
      optional :max_level, :int32, 2
      optional :similarity_metric, :string, 3
      optional :dimensions, :int32, 4
      optional :m, :int32, 5
      optional :ef, :int32, 6
      repeated :nodes, :message, 7, "HNSWGraphNode"
    end
  end
end

HNSWGraphNode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraphNode").msgclass
HNSWGraph = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraph").msgclass