roseflow 0.0.1 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (64) hide show
  1. checksums.yaml +4 -4
  2. data/.standard.yml +3 -1
  3. data/CHANGELOG.md +2 -2
  4. data/Gemfile +2 -0
  5. data/examples/github-repo-chat/lib/actions/clone_and_load_repository.rb +52 -0
  6. data/examples/github-repo-chat/lib/actions/create_prompt.rb +15 -0
  7. data/examples/github-repo-chat/lib/actions/embed_repository.rb +35 -0
  8. data/examples/github-repo-chat/lib/actions/initialize_vector_store.rb +40 -0
  9. data/examples/github-repo-chat/lib/actions/interact_with_model.rb +29 -0
  10. data/examples/github-repo-chat/lib/actions/load_documents_to_database.rb +0 -0
  11. data/examples/github-repo-chat/lib/actions/split_files_to_documents.rb +55 -0
  12. data/examples/github-repo-chat/lib/document_database.rb +0 -0
  13. data/examples/github-repo-chat/lib/github_chat_prompt.rb +24 -0
  14. data/examples/github-repo-chat/lib/github_repository_chat.rb +12 -0
  15. data/examples/github-repo-chat/lib/interactions/ask_llm.rb +31 -0
  16. data/examples/github-repo-chat/lib/interactions/github_repository_chat.rb +36 -0
  17. data/examples/github-repo-chat/lib/interactions/load_files_to_document_database.rb +18 -0
  18. data/examples/github-repo-chat/lib/interactions/load_repository.rb +20 -0
  19. data/examples/github-repo-chat/lib/interactions/prepare_vector_store.rb +21 -0
  20. data/examples/github-repo-chat/lib/repository.rb +9 -0
  21. data/examples/github-repo-chat/lib/repository_file.rb +31 -0
  22. data/examples/github-repo-chat/spec/actions/clone_and_load_repository_spec.rb +28 -0
  23. data/examples/github-repo-chat/spec/actions/embed_repository_spec.rb +24 -0
  24. data/examples/github-repo-chat/spec/actions/initialize_vector_store_spec.rb +20 -0
  25. data/examples/github-repo-chat/spec/actions/load_files_to_document_database_spec.rb +23 -0
  26. data/examples/github-repo-chat/spec/fixtures/ulid-ruby.zip +0 -0
  27. data/examples/github-repo-chat/spec/github_repository_chat_spec.rb +16 -0
  28. data/examples/github-repo-chat/spec/interactions/prepare_vector_store_spec.rb +4 -0
  29. data/examples/github-repo-chat/spec/spec_helper.rb +12 -0
  30. data/lib/roseflow/action.rb +13 -0
  31. data/lib/roseflow/actions/ai/resolve_model.rb +27 -0
  32. data/lib/roseflow/actions/ai/resolve_provider.rb +31 -0
  33. data/lib/roseflow/ai/model.rb +19 -0
  34. data/lib/roseflow/ai/provider.rb +30 -0
  35. data/lib/roseflow/chat/dialogue.rb +80 -0
  36. data/lib/roseflow/chat/exchange.rb +12 -0
  37. data/lib/roseflow/chat/message.rb +39 -0
  38. data/lib/roseflow/chat/personality.rb +10 -0
  39. data/lib/roseflow/embeddings/embedding.rb +26 -0
  40. data/lib/roseflow/finite_machine.rb +298 -0
  41. data/lib/roseflow/interaction/with_http_api.rb +10 -0
  42. data/lib/roseflow/interaction.rb +14 -0
  43. data/lib/roseflow/interaction_context.rb +10 -0
  44. data/lib/roseflow/interactions/ai/initialize_llm.rb +26 -0
  45. data/lib/roseflow/primitives/vector.rb +19 -0
  46. data/lib/roseflow/prompt.rb +17 -0
  47. data/lib/roseflow/text/completion.rb +16 -0
  48. data/lib/roseflow/text/recursive_character_splitter.rb +43 -0
  49. data/lib/roseflow/text/sentence_splitter.rb +42 -0
  50. data/lib/roseflow/text/splitter.rb +18 -0
  51. data/lib/roseflow/text/tokenized_text.rb +20 -0
  52. data/lib/roseflow/text/word_splitter.rb +14 -0
  53. data/lib/roseflow/tokenizer.rb +13 -0
  54. data/lib/roseflow/types.rb +9 -0
  55. data/lib/roseflow/vector_stores/base.rb +39 -0
  56. data/lib/roseflow/vector_stores/hnsw.proto +18 -0
  57. data/lib/roseflow/vector_stores/hnsw_memory_store.rb +442 -0
  58. data/lib/roseflow/vector_stores/hnsw_pb.rb +27 -0
  59. data/lib/roseflow/vector_stores/type/vector.rb +38 -0
  60. data/lib/roseflow/vector_stores/vector.rb +19 -0
  61. data/lib/roseflow/version.rb +12 -1
  62. data/lib/roseflow.rb +10 -1
  63. data/roseflow.gemspec +53 -0
  64. metadata +274 -7
@@ -0,0 +1,43 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "roseflow/text/splitter"
4
+
5
module Roseflow
  module Text
    # Splits text into chunks of at most +chunk_size+ characters.
    #
    # The text is split on the first separator (in priority order) that occurs
    # in it. The resulting segments are merged back into chunks joined by that
    # same separator, so a chunk's length is exactly what the size check
    # measured. A segment that is by itself longer than +chunk_size+ is split
    # again with the remaining, finer-grained separators (hence "recursive").
    # Consecutive merged chunks share trailing segments totalling at most
    # +chunk_overlap+ characters. Chunks produced by a recursive re-split share
    # no overlap with their neighbours.
    class RecursiveCharacterSplitter < Splitter
      SEPARATORS = ["\n\n", "\n", " ", ""].freeze

      # @param separators [Array<String>, nil] separators in priority order,
      #   coarsest first; nil uses SEPARATORS
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+ (characters),
      #   forwarded to Splitter
      def initialize(separators = nil, **kwargs)
        super(**kwargs)
        @separators = separators || SEPARATORS
      end

      attr_reader :chunk_size, :chunk_overlap

      # @param text [String] the text to split
      # @return [Array<String>] the chunks; empty when +text+ is empty
      def split(text)
        split_with(text, @separators)
      end

      private

      # Returns the first separator in +separators+ that occurs in +text+,
      # falling back to the last (finest) one.
      def find_separator(text, separators = @separators)
        separators.find { |separator| text.include?(separator) } || separators.last
      end

      # Splits +text+ with the best separator from +separators+ and merges the
      # pieces into chunks, recursing into pieces that are too long.
      def split_with(text, separators)
        return [] if text.empty?

        separator = find_separator(text, separators)
        return [text] if separator.nil?

        finer = separators.drop(separators.index(separator) + 1)
        chunks = []
        pending = []

        text.split(separator).each do |segment|
          # Consecutive separators produce empty segments; they carry no text.
          next if segment.empty?

          if segment.size > chunk_size && !finer.empty?
            # Cannot fit in any chunk: flush what we have, then split the
            # segment on its own with the finer separators.
            chunks << pending.join(separator) unless pending.empty?
            pending = []
            chunks.concat(split_with(segment, finer))
          else
            pending = append_segment(chunks, pending, segment, separator)
          end
        end

        chunks << pending.join(separator) unless pending.empty?
        chunks
      end

      # Adds +segment+ to the chunk being built. If it does not fit, the current
      # chunk is emitted first and the new chunk starts from its overlap tail.
      def append_segment(chunks, pending, segment, separator)
        unless pending.empty? || fits?(pending + [segment], separator)
          chunks << pending.join(separator)
          pending = overlap_tail(pending, separator)
          # The overlap must never push the new chunk past chunk_size.
          pending = pending.drop(1) until pending.empty? || fits?(pending + [segment], separator)
        end
        pending + [segment]
      end

      # Longest suffix of +segments+ whose joined length is <= chunk_overlap.
      def overlap_tail(segments, separator)
        tail = segments.dup
        tail.shift until tail.empty? || joined_size(tail, separator) <= chunk_overlap
        tail
      end

      def fits?(segments, separator)
        joined_size(segments, separator) <= chunk_size
      end

      # Length of segments.join(separator), computed without building the string.
      def joined_size(segments, separator)
        return 0 if segments.empty?

        segments.sum(&:size) + (separator.size * (segments.size - 1))
      end
    end # RecursiveCharacterSplitter
  end # Text
end # Roseflow
@@ -0,0 +1,42 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "pragmatic_segmenter"
4
+ require "roseflow/text/splitter"
5
+
6
module Roseflow
  module Text
    # Splits text into chunks made of whole sentences. Sentence boundaries come
    # from PragmaticSegmenter.
    #
    # Sentences are joined with a single space into chunks of at most
    # +chunk_size+ characters. A sentence longer than +chunk_size+ becomes a
    # chunk of its own. Consecutive chunks share trailing sentences totalling
    # at most +chunk_overlap+ characters.
    class SentenceSplitter < Splitter
      # @param language [String] language code passed to PragmaticSegmenter
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+ (characters),
      #   forwarded to Splitter
      def initialize(language: "en", **kwargs)
        super(**kwargs)
        @language = language
      end

      attr_reader :chunk_size, :chunk_overlap

      # Builds a sentence segmenter for +text+.
      #
      # A fresh segmenter is built on every call. The segmenter is constructed
      # with its text, so memoizing it would make every later #split segment the
      # first text again.
      def segmenter(text)
        PragmaticSegmenter::Segmenter.new(text: text, language: @language)
      end

      # @param text [String] the text to split
      # @return [Array<String>] the chunks; empty when no sentences are found
      def split(text)
        chunks = []
        pending = []

        segmenter(text).segment.each do |sentence|
          next if sentence.empty?

          unless pending.empty? || fits?(pending + [sentence])
            chunks << pending.join(" ")
            pending = overlap_tail(pending)
            # The overlap must never push the new chunk past chunk_size.
            pending = pending.drop(1) until pending.empty? || fits?(pending + [sentence])
          end
          pending << sentence
        end

        chunks << pending.join(" ") unless pending.empty?
        chunks
      end

      private

      # Longest suffix of +sentences+ whose joined length is <= chunk_overlap.
      def overlap_tail(sentences)
        tail = sentences.dup
        tail.shift until tail.empty? || joined_size(tail) <= chunk_overlap
        tail
      end

      def fits?(sentences)
        joined_size(sentences) <= chunk_size
      end

      # Length of sentences.join(" "), computed without building the string.
      def joined_size(sentences)
        return 0 if sentences.empty?

        sentences.sum(&:size) + sentences.size - 1
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
+ # frozen_string_literal: true
2
+
3
module Roseflow
  module Text
    # Abstract base class for text splitters. Subclasses implement #split.
    class Splitter
      # @param chunk_size [Integer] maximum chunk size; must be positive
      # @param chunk_overlap [Integer] overlap between consecutive chunks, in
      #   the same unit as +chunk_size+; must be between 0 and +chunk_size+
      # @raise [ArgumentError] if either value is out of range
      def initialize(chunk_size:, chunk_overlap:)
        raise ArgumentError, "chunk size must be positive" unless chunk_size.positive?
        raise ArgumentError, "chunk overlap cannot be negative" if chunk_overlap.negative?
        raise ArgumentError, "chunk overlap cannot exceed chunk size" if chunk_overlap > chunk_size

        @chunk_size = chunk_size
        @chunk_overlap = chunk_overlap
      end

      attr_reader :chunk_size, :chunk_overlap

      # @abstract
      # @param text [String] the text to split
      # @raise [NotImplementedError] always, unless overridden
      def split(text)
        raise NotImplementedError, "this class must be extended and the #split method implemented"
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "dry-struct"
4
+
5
module Roseflow
  module Text
    # Immutable pairing of a text with the token ids it was encoded into.
    class TokenizedText < Dry::Struct
      attribute :text, Types::String
      attribute :tokens, Types::Array.of(Types::Integer)

      # @return [Integer] the number of tokens
      def token_count
        tokens.length
      end

      # @return [String] the original text
      def to_s
        text
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "roseflow/text/recursive_character_splitter"
4
+
5
module Roseflow
  module Text
    # A RecursiveCharacterSplitter that splits on single spaces by default.
    class WordSplitter < RecursiveCharacterSplitter
      # Separators used when none (or nil) are given.
      WORD_SEPARATORS = [" "].freeze

      # @param separators [Array<String>, nil] separators to split on. nil falls
      #   back to WORD_SEPARATORS instead of leaving the splitter with no
      #   separators at all, which would make #split fail.
      # @param kwargs [Hash] +chunk_size:+ and +chunk_overlap:+
      def initialize(separators = WORD_SEPARATORS, **kwargs)
        super(separators || WORD_SEPARATORS, **kwargs)
      end
    end # WordSplitter
  end # Text
end # Roseflow
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
module Roseflow
  # Abstract tokenizer interface.
  #
  # Concrete tokenizers override #encode and #decode. Until they do, both
  # methods take exactly one argument and raise NotImplementedError.
  class Tokenizer
    %i[encode decode].each do |name|
      define_method(name) do |input|
        raise NotImplementedError, "this class must be extended and the ##{name} method implemented"
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "dry-struct"
4
+
5
# Top-level dry-types container used for struct attribute types
# (e.g. Types::String, Types::Array.of(Types::Integer)).
# NOTE(review): this constant is not namespaced under Roseflow, so it may
# clash with a host application's own `Types` module.
module Types
  include Dry.Types()

  # Attribute type accepting either a Float or an Integer.
  Number = Types::Float | Types::Integer
end
@@ -0,0 +1,39 @@
1
+ # frozen_string_literal: true
2
+
3
module Roseflow
  module VectorStores
    # Abstract vector store interface.
    #
    # Concrete stores override the methods below. Every method except
    # #has_embeddings? raises NotImplementedError until it is overridden.
    # The error message names the missing method.
    class Base
      # @param klass [Object] ignored by the base implementation
      # @return [false]
      def has_embeddings?(klass)
        false
      end

      # @abstract
      def list_vectors(namespace = nil)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      def build_vector(id, attrs)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      def create_vector(vector)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      def delete_vector(name)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      def update_vector(vector)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      def query(query)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end

      # @abstract
      def find(query)
        raise NotImplementedError, "You must implement the ##{__method__} method in your vector store"
      end
    end
  end
end
@@ -0,0 +1,18 @@
1
syntax = "proto3";

// One node of a serialized HNSW graph.
message HNSWGraphNode {
  // Node ID; neighbor lists and HNSWGraph.entrypoint_id refer to it.
  string id = 1;
  // Vector components. NOTE: proto `float` is 32-bit, so 64-bit Ruby Floats
  // lose precision on a serialize/deserialize round trip.
  repeated float vector = 2;
  // The node's level in the graph.
  int32 level = 3;
  // Neighbor IDs for all of the node's levels, flattened into one list.
  // Readers split it back per level in slices of HNSWGraph.m entries.
  // NOTE(review): that split only lines up if empty slots are kept as
  // placeholders rather than dropped.
  repeated string neighbors = 4;
}

// A serialized HNSW graph together with its construction parameters.
message HNSWGraph {
  // ID of the node searches start from.
  string entrypoint_id = 1;
  // Highest level present in the graph.
  int32 max_level = 2;
  // Distance metric name (the Ruby store understands "euclidean" and "cosine").
  string similarity_metric = 3;
  // Vector dimensionality.
  int32 dimensions = 4;
  // Neighbors kept per node and level.
  int32 m = 5;
  // Search breadth parameter.
  int32 ef = 6;
  // All nodes of the graph.
  repeated HNSWGraphNode nodes = 7;
}
@@ -0,0 +1,442 @@
1
+ # frozen_string_literal: true
2
+
3
require "set"
require "ulid"
require "roseflow/vector_stores/hnsw_pb"
5
+
6
module Roseflow
  module VectorStores
    UnsupportedSimilarityMetricError = Class.new(StandardError)

    # HNSWMemoryStore is an in-memory vector store that implements
    # the HNSW (Hierarchical Navigable Small World) algorithm.
    #
    # Graph invariants kept by #add_node and #delete_node:
    # * the entry point is a node whose level equals @max_level;
    # * two nodes are only linked on a level that both of them live on, so a
    #   node reached while searching level i always has neighbors on level i.
    class HNSWMemoryStore < Base
      PROBABILITY_FACTORS = [
        0.5,
        1 / Math::E,
      ].freeze

      # Highest level #get_random_level will ever return.
      MAX_LEVEL_CAP = 16

      # Initializes a new, empty HNSWMemoryStore.
      #
      # @param similarity_metric [Symbol, String] :euclidean or :cosine. Any
      #   other value raises UnsupportedSimilarityMetricError the first time a
      #   distance is computed.
      # @param dimensions [Integer] the number of dimensions of the vectors
      # @param m [Integer] the maximum number of neighbors kept per node and level
      # @param ef [Integer] the candidate-list size used by #nearest_neighbors
      # @return [HNSWMemoryStore] the new HNSWMemoryStore
      def initialize(similarity_metric, dimensions, m, ef)
        @similarity_metric = similarity_metric
        @dimensions = dimensions
        @m = m
        @ef = ef
        @max_level = 0
        @entrypoint = nil
        @nodes = {}
      end

      attr_accessor :nodes

      # @return [Integer] the number of nodes in the store
      def size
        # A plain method instead of ActiveSupport's `delegate`, which core
        # Ruby does not provide.
        @nodes.size
      end

      # Adds a new node to the vector store. Adding an existing ID replaces
      # the old node.
      #
      # @param node_id [String] the ID of the node
      # @param vector [Array<Float>] the vector of the node
      # @return [HNSWNode] the new node
      def add_node(node_id, vector)
        # Drop a previous node with this ID so no stale copy stays linked.
        delete_node(node_id) if @nodes.key?(node_id)

        level = get_random_level
        node = HNSWNode.new(node_id, vector, level, @m)

        if @entrypoint.nil?
          @entrypoint = node
          @max_level = level
          return @nodes[node_id] = node
        end

        # Greedy descent through the levels above the new node's top level.
        current_node = @entrypoint
        @max_level.downto(level + 1) do |i|
          current_node = search_level(vector, current_node, i)
        end

        # Link the node on each of its levels that already exists in the graph.
        # Each level's search continues from the nearest node found one level up.
        [level, @max_level].min.downto(0) do |i|
          neighbors = find_neighbors(current_node, vector, i)
          update_neighbors(node, neighbors, vector, i)
          current_node = neighbors.first || current_node
        end

        # A node taller than the graph becomes the new entry point.
        @entrypoint = node if level > @max_level
        update_max_level(level)

        @nodes[node_id] = node
      end

      alias_method :create_vector, :add_node

      # Deletes a node from the vector store and unlinks it from the graph.
      #
      # NOTE: neighbors of the deleted node are not re-linked to each other,
      # so heavy deletion can degrade search quality.
      #
      # @param node_id [String] the ID of the node
      # @return [HNSWNode, nil] the deleted node, or nil if it did not exist
      def delete_node(node_id)
        node = @nodes.delete(node_id)
        return unless node

        @nodes.each_value do |other|
          other.neighbors.each { |links| links&.delete(node) }
        end

        if @entrypoint.equal?(node)
          # Keep the invariant: the entry point lives on the highest level.
          @entrypoint = @nodes.values.max_by(&:level)
          @max_level = @entrypoint ? @entrypoint.level : 0
        end

        node
      end

      alias_method :delete_vector, :delete_node

      # Finds a node in the vector store.
      #
      # @param node_id [String] the ID of the node
      # @return [HNSWNode] the found node
      # @return [nil] if the node was not found
      def find(node_id)
        @nodes[node_id]
      end

      # Serializes the vector store to a binary protobuf string.
      #
      # Each level's neighbor list is written as exactly m IDs, with "" for
      # empty slots, so #deserialize can slice the list back into levels.
      #
      # @return [String] the encoded HNSWGraph
      def serialize
        graph = HNSWGraph.new(
          entrypoint_id: @entrypoint ? @entrypoint.id : "",
          max_level: @max_level,
          similarity_metric: @similarity_metric.to_s,
          dimensions: @dimensions,
          m: @m,
          ef: @ef,
          nodes: @nodes.values.map do |node|
            HNSWGraphNode.new(
              id: node.id,
              vector: node.vector,
              level: node.level,
              neighbors: serialize_neighbors(node),
            )
          end,
        )

        graph.to_proto
      end

      # Deserializes a binary string produced by #serialize into a vector store.
      #
      # @param serialized_data [String] an encoded HNSWGraph
      # @return [HNSWMemoryStore] the restored store
      def self.deserialize(serialized_data)
        graph = HNSWGraph.decode(serialized_data)

        hnsw = new(graph.similarity_metric, graph.dimensions, graph.m, graph.ef)

        # Create nodes
        graph.nodes.each do |node|
          hnsw.nodes[node.id] = HNSWNode.new(node.id, node.vector.to_a, node.level, graph.m)
        end

        # Set neighbors: the flat list holds m slots per level, in level order.
        graph.nodes.each do |node|
          levels = hnsw.nodes[node.id].neighbors
          node.neighbors.each_slice(graph.m).with_index do |neighbor_ids, level|
            # Ignore malformed data that lists more levels than the node has.
            next unless levels[level]

            neighbor_ids.each_with_index do |neighbor_id, index|
              levels[level][index] = hnsw.nodes[neighbor_id] if hnsw.nodes.key?(neighbor_id)
            end
          end
        end

        hnsw.instance_variable_set(:@entrypoint, hnsw.nodes[graph.entrypoint_id])
        hnsw.instance_variable_set(:@max_level, graph.max_level)

        hnsw
      end

      # Finds the m nearest neighbors of +query+ on +level+, starting at +node+.
      #
      # @return [Array<HNSWNode>] nearest first
      def find_neighbors(node, query, level)
        search_knn(node, query, @m, level)
      end

      # Links +node+ to +neighbors+ on +level+, and adds +node+ to each
      # neighbor's list on that level. Lists are kept sorted by distance and
      # trimmed to m entries.
      def update_neighbors(node, neighbors, query, level)
        node.neighbors[level] = neighbors[0, @m]

        neighbors.each do |neighbor|
          links = Array(neighbor.neighbors[level]).compact
          n_distance = distance(neighbor.vector, query)
          # Insert before the first farther link, or append when there is room.
          insert_at = links.index { |n| n_distance < distance(neighbor.vector, n.vector) } || links.size
          links.insert(insert_at, node)
          neighbor.neighbors[level] = links.first(@m)
        end
      end

      # Updates maximum level of the graph.
      def update_max_level(level)
        @max_level = level if level > @max_level
      end

      # Finds the k nearest neighbors of a vector.
      #
      # @param query [Array<Float>] the query vector
      # @param k [Integer] the number of neighbors wanted
      # @return [Array<HNSWNode>] up to k nodes, nearest first
      def nearest_neighbors(query, k)
        return [] unless @entrypoint && k.positive?

        entry_point = @entrypoint
        @max_level.downto(1) do |level|
          entry_point = search_level(query, entry_point, level)
        end
        # A candidate list of at least ef entries improves recall; trim to k.
        search_knn(entry_point, query, [@ef.to_i, k].max, 0).first(k)
      end

      # Greedily walks +level+ from +entry_point+ towards +query+ and returns
      # the local minimum.
      def search_level(query, entry_point, level)
        current = entry_point
        best_distance = distance(query, current.vector)

        loop do
          closest_neighbor, closest_distance = find_closest_neighbor(query, Array(current.neighbors[level]))
          break unless closest_neighbor && closest_distance < best_distance

          best_distance = closest_distance
          current = closest_neighbor
        end

        current
      end

      # @return [Array(HNSWNode, Float)] the closest non-nil neighbor and its
      #   distance, or [nil, Float::INFINITY]
      def find_closest_neighbor(query, neighbors)
        closest_neighbor = nil
        closest_distance = Float::INFINITY

        neighbors.each do |neighbor|
          next unless neighbor

          distance = distance(query, neighbor.vector)
          if distance < closest_distance
            closest_distance = distance
            closest_neighbor = neighbor
          end
        end

        [closest_neighbor, closest_distance]
      end

      # Finds the k nearest neighbors of a vector on +level+.
      #
      # Candidates are expanded nearest-first. The search stops once the
      # nearest remaining candidate is farther than the worst of the k results.
      #
      # @return [Array<HNSWNode>] up to k nodes, nearest first
      def search_knn(entry_point, query, k, level)
        return [] unless k.positive?

        visited = Set.new
        candidates = Set.new([entry_point])
        result = []

        until candidates.empty?
          closest = find_closest_candidate(candidates, query)
          candidates.delete(closest)
          # Check before inserting: afterwards the candidate would be compared
          # against itself and the search would stop too early.
          break if termination_condition_met?(result, closest, query, k)

          visited.add(closest.id)
          result = update_result(result, closest, query, k)
          add_neighbors_to_candidates(closest, level, visited, candidates)
        end

        result.sort_by { |node| distance(query, node.vector) }
      end

      def find_closest_candidate(candidates, query)
        candidates.min_by { |c| distance(query, c.vector) }
      end

      # Adds +candidate+ to +result+, evicting the farthest entry when full.
      def update_result(result, candidate, query, k)
        if result.size < k
          result.push(candidate)
        else
          furthest_result = result.max_by { |r| distance(query, r.vector) }
          closest_distance = distance(query, candidate.vector)
          furthest_result_distance = distance(query, furthest_result.vector)

          if closest_distance < furthest_result_distance
            result.delete(furthest_result)
            result.push(candidate)
          end
        end
        result
      end

      # True when the result is full and +closest+ is strictly farther than
      # its worst entry, so no remaining candidate can improve it.
      def termination_condition_met?(result, closest, query, k)
        return false if result.size < k

        furthest_result_distance = distance(query, result.max_by { |r| distance(query, r.vector) }.vector)
        closest_distance = distance(query, closest.vector)

        closest_distance > furthest_result_distance
      end

      def add_neighbors_to_candidates(closest, level, visited, candidates)
        Array(closest.neighbors[level]).each do |neighbor|
          next unless neighbor
          next if visited.include?(neighbor.id)

          candidates.add(neighbor)
        end
      end

      # Calculates the distance between two vectors.
      #
      # @raise [UnsupportedSimilarityMetricError] for metrics other than
      #   euclidean and cosine
      def distance(from, to)
        case @similarity_metric.to_sym
        when :euclidean
          euclidean_distance(from, to)
        when :cosine
          cosine_distance(from, to)
        else
          raise UnsupportedSimilarityMetricError, "Similarity metric #{@similarity_metric} is not supported"
        end
      end

      # Calculates the euclidean distance between two vectors.
      def euclidean_distance(from, to)
        Math.sqrt(from.each_with_index.sum { |value, index| (value - to[index])**2 })
      end

      # Calculates the cosine distance (1 - cosine similarity) between two vectors.
      def cosine_distance(from, to)
        dot_product = 0
        norm_from = 0
        norm_to = 0

        from.each_with_index do |value, index|
          dot_product += value * to[index]
          norm_from += value**2
          norm_to += to[index]**2
        end

        denominator = Math.sqrt(norm_from) * Math.sqrt(norm_to)
        # A zero vector has no direction. Treat it as orthogonal instead of
        # returning NaN, which would make every comparison false.
        return 1.0 if denominator.zero?

        1 - (dot_product / denominator)
      end

      # Returns a random level for a node: level l is reached with
      # probability PROBABILITY_FACTORS[0]**l, capped at MAX_LEVEL_CAP.
      # The cap is a constant. Bounding the level by the current max level
      # instead would keep an empty graph at level 0 forever.
      def get_random_level
        level = 0
        level += 1 while rand < PROBABILITY_FACTORS[0] && level < MAX_LEVEL_CAP
        level
      end

      private

      # Flattens a node's neighbor lists into exactly m IDs per level,
      # padding empty slots with "" (which #deserialize skips).
      def serialize_neighbors(node)
        node.neighbors.flat_map do |links|
          ids = Array(links).compact.first(@m).map(&:id)
          ids + Array.new(@m - ids.size, "")
        end
      end

      # HNSW vector store node.
      class HNSWNode
        attr_reader :id, :vector
        attr_accessor :level, :neighbors

        # Initializes a new node.
        #
        # @param id [String] the node ID (ULID)
        # @param vector [Array] the node vector
        # @param level [Integer] the node level
        # @param m [Integer] the number of neighbors
        # @return [HNSWNode] the node
        def initialize(id, vector, level, m)
          @id = id
          @vector = vector
          @level = level
          # One neighbor list per level, 0..level.
          @neighbors = Array.new(level + 1) { Array.new(m) }
        end
      end

      # BoundedPriorityQueue keeps at most max_size items and retains the ones
      # with the smallest priorities. Items are arrays whose first element is
      # a numeric priority.
      class BoundedPriorityQueue
        def initialize(max_size)
          @max_size = max_size
          # Min-heap on the negated priority, i.e. a max-heap on priority, so
          # the item to evict (largest priority) is always on top.
          @queue = PriorityQueue.new
        end

        def size
          @queue.size
        end

        # Inserts an item. When the queue is full, the item replaces the
        # current largest-priority item only if its own priority is smaller.
        def push(item)
          return unless @max_size.positive?

          if size < @max_size
            @queue.push([-item[0], item])
          elsif item[0] < @queue.peek[1][0]
            @queue.pop
            @queue.push([-item[0], item])
          end
        end

        # Returns the item with the smallest priority without removing it,
        # or nil when empty.
        def peek
          to_a.first
        end

        # @return [Array] the retained items, smallest priority first
        def to_a
          @queue.to_a.map(&:last).sort_by(&:first)
        end
      end

      # PriorityQueue is a data structure that keeps elements ordered by priority.
      # It supports inserting elements, removing the element with the smallest
      # priority, and peeking at the element with the smallest priority. It uses
      # a binary heap as the underlying data structure. Elements are arrays
      # whose first element is the priority.
      class PriorityQueue
        def initialize
          @elements = []
        end

        def size
          @elements.size
        end

        def empty?
          @elements.empty?
        end

        def push(item)
          @elements << item
          shift_up(@elements.size - 1)
        end

        # Removes and returns the element with the smallest priority.
        # Returns nil if the PriorityQueue is empty.
        def pop
          return if empty?

          swap(0, @elements.size - 1)
          element = @elements.pop
          shift_down(0)
          element
        end

        # Returns the element with the smallest priority without removing it
        # from the PriorityQueue.
        def peek
          @elements.first
        end

        # @return [Array] a copy of the elements in heap order
        def to_a
          @elements.dup
        end

        private

        def swap(i, j)
          @elements[i], @elements[j] = @elements[j], @elements[i]
        end

        def shift_up(i)
          parent = (i - 1) / 2
          return if i <= 0 || @elements[parent][0] <= @elements[i][0]

          swap(i, parent)
          shift_up(parent)
        end

        def shift_down(i)
          left_child = (2 * i) + 1
          right_child = (2 * i) + 2

          min = i
          min = left_child if left_child < size && @elements[left_child][0] < @elements[min][0]
          min = right_child if right_child < size && @elements[right_child][0] < @elements[min][0]

          return if min == i

          swap(i, min)
          shift_down(min)
        end
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: hnsw.proto
3
+
4
+ require "google/protobuf"
5
+
6
# Registers the hnsw.proto schema with the global descriptor pool.
# NOTE(review): this file is generated; change hnsw.proto and regenerate it
# with protoc instead of editing by hand.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("hnsw.proto", syntax: :proto3) do
    add_message "HNSWGraphNode" do
      optional :id, :string, 1
      repeated :vector, :float, 2
      optional :level, :int32, 3
      repeated :neighbors, :string, 4
    end
    add_message "HNSWGraph" do
      optional :entrypoint_id, :string, 1
      optional :max_level, :int32, 2
      optional :similarity_metric, :string, 3
      optional :dimensions, :int32, 4
      optional :m, :int32, 5
      optional :ef, :int32, 6
      repeated :nodes, :message, 7, "HNSWGraphNode"
    end
  end
end

# Message classes are bound to top-level constants because hnsw.proto
# declares no package.
HNSWGraphNode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraphNode").msgclass
HNSWGraph = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("HNSWGraph").msgclass