leann 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,476 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+
5
module Leann
  module Backend
    # LEANN graph-only backend: persists only the HNSW graph structure.
    #
    # Embeddings are needed while the graph is being built, but they are never
    # written to disk. That is where the storage savings come from; the exact
    # ratio depends on dimensionality and M and is not measured here.
    #
    # Persisted data:
    # - Node IDs and node levels (which layers each node participates in)
    # - Neighbor lists per level, as indices into the node array
    # - Entry point and HNSW parameters (JSON sidecar file)
    #
    # During search, embeddings of visited nodes are recomputed on the fly via
    # the embedding provider passed to #search.
    class LeannGraph
      # HNSW parameters
      DEFAULT_M = 16 # Max connections per node on layers >= 1 (layer 0 allows 2x)
      DEFAULT_EF_CONSTRUCTION = 200 # Candidate-list width used while building
      DEFAULT_ML = 1.0 / Math.log(DEFAULT_M) # Level multiplier for DEFAULT_M

      # Hard cap on a node's level so random_level always terminates quickly.
      MAX_LEVEL_CAP = 32
      # IDs are written with a uint16 length prefix, so longer IDs cannot be stored.
      MAX_ID_BYTES = 0xFFFF

      attr_reader :dimensions, :m, :ef_construction, :max_level, :entry_point
      attr_reader :node_count

      # @param dimensions [Integer] embedding dimensionality; only recorded in the saved metadata
      # @param m [Integer] max neighbors per node on layers >= 1 (2 * m on layer 0); must be >= 2
      # @param ef_construction [Integer] candidate-list width used by #build
      # @param random [#rand] source of uniform floats in [0, 1) used to draw node levels.
      #   Defaults to the Random class (the process-wide default generator); pass a seeded
      #   Random instance for reproducible graphs.
      # @raise [ArgumentError] if m < 2 (level sampling degenerates for such values)
      def initialize(dimensions:, m: DEFAULT_M, ef_construction: DEFAULT_EF_CONSTRUCTION, random: Random)
        raise ArgumentError, "m must be >= 2 (got #{m.inspect})" unless m.is_a?(Numeric) && m >= 2

        @dimensions = dimensions
        @m = m
        @m0 = m * 2 # Layer 0 has 2x connections
        @ef_construction = ef_construction
        @random = random
        clear_graph
      end

      # Build the graph from embeddings. After building, embeddings can be discarded.
      # Any previously built or loaded graph in this instance is replaced.
      #
      # @param ids [Array<String>] Document IDs (node index i corresponds to ids[i])
      # @param embeddings [Array<Array<Float>>] Embedding vectors, parallel to ids
      # @return [self]
      # @raise [ArgumentError] if ids and embeddings differ in length
      def build(ids, embeddings)
        raise ArgumentError, "IDs and embeddings must have same length" unless ids.length == embeddings.length

        # Start from a clean slate: node indices below are assumed to start at 0.
        clear_graph
        return self if ids.empty?

        @node_count = ids.length
        puts "Building LEANN graph with #{@node_count} nodes (M=#{@m})..."

        ids.each_with_index do |id, idx|
          level = random_level
          @nodes << { id: id, level: level }
          @id_to_idx[id] = idx
          @neighbors << Array.new(level + 1) { [] }

          if @entry_point.nil?
            # First node becomes the entry point; there is nothing to link to yet.
            @entry_point = idx
            @max_level = level
          else
            insert_node(idx, embeddings[idx], embeddings, level)
          end

          print "." if ((idx + 1) % 100).zero?
        end

        puts "\nGraph built: #{@node_count} nodes, max_level=#{@max_level}"
        self
      end

      # Save the graph to "<path>.graph.meta.json" (parameters) and "<path>.graph.bin"
      # (structure). No embeddings are written.
      #
      # Binary layout (little-endian):
      #   uint64 node_count
      #   int32[node_count] node levels
      #   node_count x (uint16 byte length, UTF-8 ID bytes)
      #   uint64[node_count + 1] per-node start offsets into the flat neighbor array
      #   uint64 count, uint64[count] per-node level offsets (relative to the node's
      #     start, one entry per level plus an end marker per node)
      #   uint64 count, int32[count] flattened neighbor indices
      #
      # IDs are stored via #to_s, so they come back as UTF-8 Strings after #load.
      #
      # @param path [String] Base path for index files
      # @return [nil]
      # @raise [ArgumentError] if an ID is longer than MAX_ID_BYTES bytes (checked before
      #   any file is written)
      def save(path)
        graph_file = "#{path}.graph.bin"
        meta_file = "#{path}.graph.meta.json"

        # Encode and validate IDs first so a bad ID cannot leave a half-written index.
        encoded_ids = @nodes.map do |node|
          bytes = node[:id].to_s.encode("UTF-8")
          if bytes.bytesize > MAX_ID_BYTES
            raise ArgumentError, "ID too long to store (#{bytes.bytesize} bytes, max #{MAX_ID_BYTES})"
          end

          bytes
        end

        meta = {
          version: "1.0",
          format: "leann_graph",
          node_count: @node_count,
          dimensions: @dimensions,
          m: @m,
          ef_construction: @ef_construction,
          max_level: @max_level,
          entry_point: @entry_point
        }
        File.write(meta_file, JSON.pretty_generate(meta))

        File.open(graph_file, "wb") { |f| write_graph(f, encoded_ids) }

        graph_size = File.size(graph_file)
        puts "Graph saved: #{format_bytes(graph_size)} (no embeddings stored!)"
      end

      # Load a graph previously written by #save.
      #
      # @param path [String] Base path for index files
      # @return [LeannGraph]
      # @raise [IndexNotFoundError] if the graph or metadata file is missing
      def self.load(path)
        graph_file = "#{path}.graph.bin"
        meta_file = "#{path}.graph.meta.json"

        raise IndexNotFoundError, "Graph file not found: #{graph_file}" unless File.exist?(graph_file)
        raise IndexNotFoundError, "Graph metadata not found: #{meta_file}" unless File.exist?(meta_file)

        meta = JSON.parse(File.read(meta_file), symbolize_names: true)

        graph = new(
          dimensions: meta[:dimensions],
          m: meta[:m],
          ef_construction: meta[:ef_construction]
        )
        graph.load_from_files(path, meta)
        graph
      end

      # Load graph structure from "<path>.graph.bin" (layout documented on #save).
      #
      # @param path [String] Base path for index files
      # @param meta [Hash] parsed metadata with symbol keys (:node_count, :max_level, :entry_point)
      # @raise [RuntimeError] if the node count disagrees with meta or the file is truncated
      def load_from_files(path, meta)
        graph_file = "#{path}.graph.bin"

        @node_count = meta[:node_count]
        @max_level = meta[:max_level]
        @entry_point = meta[:entry_point]

        File.open(graph_file, "rb") do |f|
          node_count = read_exact(f, 8).unpack1("Q<")
          raise "Node count mismatch" unless node_count == @node_count

          levels = read_exact(f, @node_count * 4).unpack("l<*")

          @nodes = []
          @id_to_idx = {}
          levels.each_with_index do |level, idx|
            id_len = read_exact(f, 2).unpack1("S<")
            id = read_exact(f, id_len).force_encoding("UTF-8")
            @nodes << { id: id, level: level }
            @id_to_idx[id] = idx
          end

          offsets = read_exact(f, (@node_count + 1) * 8).unpack("Q<*")

          level_offsets_count = read_exact(f, 8).unpack1("Q<")
          level_offsets = read_exact(f, level_offsets_count * 8).unpack("Q<*")

          neighbors_count = read_exact(f, 8).unpack1("Q<")
          all_neighbors = read_exact(f, neighbors_count * 4).unpack("l<*")

          @neighbors = rebuild_neighbors(levels, offsets, level_offsets, all_neighbors)
        end

        puts "Graph loaded: #{@node_count} nodes, max_level=#{@max_level}"
      end

      # Search the graph, recomputing embeddings of visited nodes on the fly.
      #
      # @param query_embedding [Array<Float>] Query vector
      # @param embedding_provider [#compute_one] called as compute_one(text) to get a node's vector
      # @param passages [#[]] id => text lookup used for recomputation
      # @param limit [Integer] Number of results
      # @param ef [Integer, nil] Search width (higher = more accurate, slower); defaults to
      #   max(2 * limit, 10) and is never allowed below limit
      # @return [Array<[String, Float]>] [id, cosine similarity] pairs, most similar first
      # @raise [RuntimeError] if a visited node's ID has no passage
      def search(query_embedding, embedding_provider:, passages:, limit: 5, ef: nil)
        return [] if @entry_point.nil?

        ef ||= [limit * 2, 10].max
        # A candidate list narrower than limit could never return limit results.
        ef = [ef, limit].max

        # Embeddings computed during this search, keyed by node index.
        embedding_cache = {}
        dist_to = lambda do |idx|
          distance(query_embedding, get_embedding(idx, embedding_provider, passages, embedding_cache))
        end

        # Greedy descent from the top layer to layer 1, then a wide search on layer 0.
        current, = greedy_descend(@entry_point, dist_to.call(@entry_point), @max_level, 1, &dist_to)
        candidates = search_layer(query_embedding, current, ef, 0, embedding_provider, passages, embedding_cache)

        # Candidates are already sorted by ascending distance.
        candidates
          .first(limit)
          .map { |idx, dist| [get_id(idx), 1.0 - dist] } # Convert distance to similarity
      end

      # Neighbor indices of a node at a level; [] if the node or level does not exist.
      def get_neighbors(node_idx, level)
        return [] if node_idx >= @neighbors.length
        return [] if level >= @neighbors[node_idx].length

        @neighbors[node_idx][level] || []
      end

      # Document ID for a node index.
      def get_id(node_idx)
        @nodes[node_idx][:id]
      end

      # Node index for a document ID, or nil if unknown.
      def get_idx(id)
        @id_to_idx[id]
      end

      private

      # Reset all graph state to "empty".
      def clear_graph
        @nodes = [] # Array of node data {id:, level:}
        @id_to_idx = {} # Map document ID to node index
        @neighbors = [] # neighbors[idx][level] = [neighbor_indices]
        @entry_point = nil
        @max_level = -1
        @node_count = 0
      end

      # Draw a level with P(level >= k) = (1/m)^k, capped at MAX_LEVEL_CAP.
      def random_level
        promote_p = 1.0 / @m
        level = 0
        level += 1 while @random.rand < promote_p && level < MAX_LEVEL_CAP
        level
      end

      # Cosine distance = 1 - cosine similarity; 1.0 when either vector has zero norm.
      def distance(a, b)
        dot = 0.0
        norm_a = 0.0
        norm_b = 0.0
        a.each_with_index do |x, i|
          y = b[i]
          dot += x * y
          norm_a += x * x
          norm_b += y * y
        end
        denom = Math.sqrt(norm_a) * Math.sqrt(norm_b)
        return 1.0 if denom.zero?

        1.0 - (dot / denom)
      end

      # Embedding for a node, computed via the provider once per search and cached.
      def get_embedding(node_idx, embedding_provider, passages, cache)
        cache.fetch(node_idx) do
          id = @nodes[node_idx][:id]
          text = passages[id]
          raise "Passage not found for ID: #{id}" unless text

          cache[node_idx] = embedding_provider.compute_one(text)
        end
      end

      # Link a new node into every layer from min(level, max_level) down to 0.
      def insert_node(idx, embedding, all_embeddings, level)
        dist_to = ->(other) { distance(embedding, all_embeddings[other]) }

        # Greedy descent through layers above the new node's top level.
        ep, = greedy_descend(@entry_point, dist_to.call(@entry_point), @max_level, level + 1, &dist_to)

        [level, @max_level].min.downto(0) do |lc|
          max_neighbors = lc.zero? ? @m0 : @m

          candidates = search_layer_build(embedding, ep, @ef_construction, lc, all_embeddings)
          neighbors = select_neighbors(candidates, max_neighbors)

          @neighbors[idx][lc] = neighbors
          neighbors.each { |nb| add_reverse_edge(nb, idx, lc, max_neighbors, all_embeddings) }

          # The closest neighbor seeds the search on the next layer down.
          ep = neighbors.first unless neighbors.empty?
        end

        return unless level > @max_level

        @entry_point = idx
        @max_level = level
      end

      # Add new_idx to node's list at level, pruning to the closest max_neighbors if full.
      def add_reverse_edge(node, new_idx, level, max_neighbors, all_embeddings)
        links = @neighbors[node][level]
        if links.length < max_neighbors
          links << new_idx
          return
        end

        node_emb = all_embeddings[node]
        scored = (links + [new_idx]).map { |n| [n, distance(node_emb, all_embeddings[n])] }
        @neighbors[node][level] = select_neighbors(scored, max_neighbors)
      end

      # Greedy walk on each layer from +top+ down to +bottom+: move to any neighbor
      # closer than the current node until no neighbor improves. The block maps a
      # node index to its distance from the query. Returns [node_idx, distance].
      def greedy_descend(start, start_dist, top, bottom)
        current = start
        current_dist = start_dist
        top.downto(bottom) do |level|
          loop do
            improved = false
            get_neighbors(current, level).each do |neighbor|
              d = yield(neighbor)
              next unless d < current_dist

              current = neighbor
              current_dist = d
              improved = true
            end
            break unless improved
          end
        end
        [current, current_dist]
      end

      # Layer search during build (distances from stored embeddings).
      def search_layer_build(query_emb, entry_point, ef, level, all_embeddings)
        search_layer_with(entry_point, ef, level) { |idx| distance(query_emb, all_embeddings[idx]) }
      end

      # Layer search at query time (distances from recomputed embeddings).
      def search_layer(query_emb, entry_point, ef, level, embedding_provider, passages, cache)
        search_layer_with(entry_point, ef, level) do |idx|
          distance(query_emb, get_embedding(idx, embedding_provider, passages, cache))
        end
      end

      # Best-first search on one layer keeping the ef closest nodes found.
      # The block maps a node index to its distance from the query.
      # Returns [[node_idx, distance], ...] sorted by ascending distance.
      def search_layer_with(entry_point, ef, level)
        ef = [ef, 1].max
        entry_dist = yield(entry_point)
        visited = { entry_point => true }
        candidates = [[entry_point, entry_dist]] # frontier, closest first
        results = [[entry_point, entry_dist]] # best ef so far, closest first

        until candidates.empty?
          current, current_dist = candidates.shift

          # Stop once the closest unexpanded node is worse than the worst kept result.
          break if results.length >= ef && current_dist > results.last[1]

          get_neighbors(current, level).each do |neighbor|
            next if visited.key?(neighbor)

            visited[neighbor] = true
            neighbor_dist = yield(neighbor)
            next unless results.length < ef || neighbor_dist < results.last[1]

            insert_sorted(candidates, neighbor, neighbor_dist)
            insert_sorted(results, neighbor, neighbor_dist)
            results.pop if results.length > ef
          end
        end

        results
      end

      # Insert [idx, dist] into a list kept sorted by ascending distance.
      def insert_sorted(list, idx, dist)
        pos = list.bsearch_index { |(_, d)| d > dist } || list.length
        list.insert(pos, [idx, dist])
      end

      # Simple selection: the max_count closest [idx, distance] candidates, closest first.
      def select_neighbors(candidates, max_count)
        candidates.min_by(max_count) { |_, dist| dist }.map(&:first)
      end

      # Read exactly nbytes or raise (a short read means the file is truncated).
      def read_exact(io, nbytes)
        data = io.read(nbytes)
        actual = data ? data.bytesize : 0
        return data if actual == nbytes

        raise "Truncated graph file: expected #{nbytes} bytes, got #{actual}"
      end

      # Write the binary body described on #save.
      def write_graph(io, encoded_ids)
        io.write([@node_count].pack("Q<")) # uint64 node count
        io.write(@nodes.map { |n| n[:level] }.pack("l<*")) # int32 levels

        encoded_ids.each do |bytes|
          io.write([bytes.bytesize].pack("S<")) # uint16 length
          io.write(bytes)
        end

        # CSR-style per-node start offsets, plus a final total.
        node_totals = @neighbors.map { |node_levels| node_levels.sum(&:length) }
        offsets = node_totals.each_with_object([0]) { |total, acc| acc << acc.last + total }
        io.write(offsets.pack("Q<*"))

        # Per-node level offsets relative to the node's start, each ending with a marker.
        level_offsets = @neighbors.flat_map do |node_levels|
          node_levels.each_with_object([0]) { |lst, acc| acc << acc.last + lst.length }
        end
        io.write([level_offsets.length].pack("Q<"))
        io.write(level_offsets.pack("Q<*"))

        all_neighbors = @neighbors.flatten
        io.write([all_neighbors.length].pack("Q<"))
        io.write(all_neighbors.pack("l<*")) # int32 neighbor indices
      end

      # Rebuild neighbors[idx][level] from the flattened on-disk arrays.
      def rebuild_neighbors(levels, offsets, level_offsets, all_neighbors)
        cursor = 0
        levels.each_with_index.map do |node_level, idx|
          base = offsets[idx]
          node_neighbors = (0..node_level).map do |lvl|
            start_off = level_offsets[cursor + lvl]
            end_off = level_offsets[cursor + lvl + 1]
            all_neighbors[(base + start_off)...(base + end_off)] || []
          end
          cursor += node_level + 2 # one entry per level plus this node's end marker
          node_neighbors
        end
      end

      def format_bytes(bytes)
        if bytes < 1024
          "#{bytes} B"
        elsif bytes < 1024 * 1024
          "#{(bytes / 1024.0).round(1)} KB"
        else
          "#{(bytes / (1024.0 * 1024)).round(2)} MB"
        end
      end
    end
  end
end