clusterkit 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +3236 -0
- data/ext/clusterkit/Cargo.toml +2 -1
- data/ext/clusterkit/extconf.rb +9 -1
- data/ext/clusterkit/src/clustering/hdbscan_wrapper.rs +23 -36
- data/ext/clusterkit/src/clustering.rs +47 -53
- data/ext/clusterkit/src/embedder.rs +44 -52
- data/ext/clusterkit/src/hnsw.rs +181 -215
- data/ext/clusterkit/src/lib.rs +5 -5
- data/ext/clusterkit/src/svd.rs +31 -33
- data/ext/clusterkit/src/utils.rs +24 -21
- data/lib/clusterkit/version.rb +1 -1
- data/lib/clusterkit.rb +1 -1
- metadata +18 -4
- data/clusterkit.gemspec +0 -45
data/ext/clusterkit/src/hnsw.rs
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
use magnus::{
|
|
2
|
-
|
|
3
|
-
Error, Float, Integer, RArray, RHash, RString, Symbol, Value,
|
|
2
|
+
function, method, prelude::*,
|
|
3
|
+
Error, Float, Integer, RArray, RHash, RString, Symbol, Value, TryConvert, Ruby,
|
|
4
|
+
r_hash::ForEach,
|
|
4
5
|
};
|
|
5
6
|
use hnsw_rs::prelude::*;
|
|
6
7
|
use hnsw_rs::hnswio::HnswIo;
|
|
7
|
-
// use ndarray::Array1; // Not used currently
|
|
8
8
|
use std::collections::HashMap;
|
|
9
9
|
use std::sync::{Arc, Mutex};
|
|
10
10
|
use serde::{Serialize, Deserialize};
|
|
@@ -30,7 +30,7 @@ pub struct HnswIndex {
|
|
|
30
30
|
}
|
|
31
31
|
|
|
32
32
|
#[derive(Clone, Copy)]
|
|
33
|
-
#[allow(dead_code)]
|
|
33
|
+
#[allow(dead_code)]
|
|
34
34
|
enum DistanceType {
|
|
35
35
|
Euclidean,
|
|
36
36
|
Cosine,
|
|
@@ -38,88 +38,83 @@ enum DistanceType {
|
|
|
38
38
|
}
|
|
39
39
|
|
|
40
40
|
impl HnswIndex {
|
|
41
|
-
// Initialize a new HNSW index
|
|
42
41
|
pub fn new(kwargs: RHash) -> Result<Self, Error> {
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
let
|
|
42
|
+
let ruby = Ruby::get().unwrap();
|
|
43
|
+
|
|
44
|
+
let dim_opt: Option<Value> = kwargs.delete(ruby.to_symbol("dim"))?;
|
|
45
|
+
let dim_value = dim_opt.ok_or_else(|| Error::new(ruby.exception_arg_error(), "dim is required"))?;
|
|
46
46
|
let dim: usize = TryConvert::try_convert(dim_value)
|
|
47
|
-
.map_err(|_| Error::new(
|
|
48
|
-
|
|
49
|
-
// Validate dimension
|
|
47
|
+
.map_err(|_| Error::new(ruby.exception_arg_error(), "dim must be an integer"))?;
|
|
48
|
+
|
|
50
49
|
if dim == 0 {
|
|
51
|
-
return Err(Error::new(
|
|
50
|
+
return Err(Error::new(ruby.exception_arg_error(), "dim must be a positive integer (got 0)"));
|
|
52
51
|
}
|
|
53
|
-
|
|
54
|
-
let space: String = if let Some(v) = kwargs.delete(
|
|
55
|
-
// Convert Ruby symbol to string properly
|
|
52
|
+
|
|
53
|
+
let space: String = if let Some(v) = kwargs.delete(ruby.to_symbol("space"))? {
|
|
56
54
|
if let Ok(sym) = Symbol::try_convert(v) {
|
|
57
55
|
sym.name()?.to_string()
|
|
58
56
|
} else if let Ok(s) = String::try_convert(v) {
|
|
59
57
|
s
|
|
60
58
|
} else {
|
|
61
59
|
return Err(Error::new(
|
|
62
|
-
|
|
60
|
+
ruby.exception_type_error(),
|
|
63
61
|
"space must be a string or symbol"
|
|
64
62
|
));
|
|
65
63
|
}
|
|
66
64
|
} else {
|
|
67
65
|
"euclidean".to_string()
|
|
68
66
|
};
|
|
69
|
-
|
|
70
|
-
let max_elements: usize = if let Some(v) = kwargs.delete(
|
|
67
|
+
|
|
68
|
+
let max_elements: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("max_elements"))? {
|
|
71
69
|
TryConvert::try_convert(v).unwrap_or(10_000)
|
|
72
70
|
} else {
|
|
73
71
|
10_000
|
|
74
72
|
};
|
|
75
|
-
|
|
76
|
-
let m: usize = if let Some(v) = kwargs.delete(
|
|
73
|
+
|
|
74
|
+
let m: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("M"))? {
|
|
77
75
|
TryConvert::try_convert(v).unwrap_or(16)
|
|
78
76
|
} else {
|
|
79
77
|
16
|
|
80
78
|
};
|
|
81
|
-
|
|
82
|
-
let ef_construction: usize = if let Some(v) = kwargs.delete(
|
|
79
|
+
|
|
80
|
+
let ef_construction: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("ef_construction"))? {
|
|
83
81
|
TryConvert::try_convert(v).unwrap_or(200)
|
|
84
82
|
} else {
|
|
85
83
|
200
|
|
86
84
|
};
|
|
87
|
-
|
|
88
|
-
let random_seed: Option<u64> = if let Some(v) = kwargs.delete(
|
|
85
|
+
|
|
86
|
+
let random_seed: Option<u64> = if let Some(v) = kwargs.delete(ruby.to_symbol("random_seed"))? {
|
|
89
87
|
TryConvert::try_convert(v).ok()
|
|
90
88
|
} else {
|
|
91
89
|
None
|
|
92
90
|
};
|
|
93
|
-
|
|
94
|
-
// Validate and convert space parameter
|
|
95
|
-
// For now, only support Euclidean distance
|
|
91
|
+
|
|
96
92
|
let distance_type = match space.as_str() {
|
|
97
93
|
"euclidean" => DistanceType::Euclidean,
|
|
98
94
|
"cosine" => {
|
|
99
95
|
return Err(Error::new(
|
|
100
|
-
|
|
96
|
+
ruby.exception_runtime_error(),
|
|
101
97
|
"Cosine distance is not yet implemented, please use :euclidean"
|
|
102
98
|
));
|
|
103
99
|
},
|
|
104
100
|
"inner_product" => {
|
|
105
101
|
return Err(Error::new(
|
|
106
|
-
|
|
102
|
+
ruby.exception_runtime_error(),
|
|
107
103
|
"Inner product distance is not yet implemented, please use :euclidean"
|
|
108
104
|
));
|
|
109
105
|
},
|
|
110
106
|
_ => return Err(Error::new(
|
|
111
|
-
|
|
107
|
+
ruby.exception_arg_error(),
|
|
112
108
|
format!("space must be :euclidean, :cosine, or :inner_product (got: {})", space)
|
|
113
109
|
)),
|
|
114
110
|
};
|
|
115
|
-
|
|
116
|
-
// Create HNSW instance with Euclidean distance
|
|
111
|
+
|
|
117
112
|
let hnsw = if let Some(seed) = random_seed {
|
|
118
113
|
Hnsw::<f32, DistL2>::new_with_seed(m, max_elements, 16, ef_construction, DistL2, seed)
|
|
119
114
|
} else {
|
|
120
115
|
Hnsw::<f32, DistL2>::new(m, max_elements, 16, ef_construction, DistL2)
|
|
121
116
|
};
|
|
122
|
-
|
|
117
|
+
|
|
123
118
|
Ok(Self {
|
|
124
119
|
hnsw: Arc::new(Mutex::new(hnsw)),
|
|
125
120
|
dim,
|
|
@@ -130,14 +125,13 @@ impl HnswIndex {
|
|
|
130
125
|
ef_search: Arc::new(Mutex::new(ef_construction)),
|
|
131
126
|
})
|
|
132
127
|
}
|
|
133
|
-
|
|
134
|
-
// Add a single item to the index
|
|
128
|
+
|
|
135
129
|
pub fn add_item(&self, vector: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
let label: String = if let Some(v) = kwargs.delete(
|
|
130
|
+
let ruby = Ruby::get().unwrap();
|
|
131
|
+
|
|
132
|
+
let vec_data = parse_vector(&ruby, vector, self.dim)?;
|
|
133
|
+
|
|
134
|
+
let label: String = if let Some(v) = kwargs.delete(ruby.to_symbol("label"))? {
|
|
141
135
|
TryConvert::try_convert(v).unwrap_or_else(|_| {
|
|
142
136
|
let mut id = self.current_id.lock().unwrap();
|
|
143
137
|
let label = id.to_string();
|
|
@@ -150,33 +144,30 @@ impl HnswIndex {
|
|
|
150
144
|
*id += 1;
|
|
151
145
|
label
|
|
152
146
|
};
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
Some(parse_metadata(v)?)
|
|
147
|
+
|
|
148
|
+
let metadata: Option<HashMap<String, String>> = if let Some(v) = kwargs.delete(ruby.to_symbol("metadata"))? {
|
|
149
|
+
Some(parse_metadata(&ruby, v)?)
|
|
157
150
|
} else {
|
|
158
151
|
None
|
|
159
152
|
};
|
|
160
|
-
|
|
161
|
-
// Get internal ID for this item
|
|
153
|
+
|
|
162
154
|
let internal_id = {
|
|
163
155
|
let mut label_map = self.label_to_id.lock().unwrap();
|
|
164
156
|
let mut current_id = self.current_id.lock().unwrap();
|
|
165
|
-
|
|
157
|
+
|
|
166
158
|
if label_map.contains_key(&label) {
|
|
167
159
|
return Err(Error::new(
|
|
168
|
-
|
|
160
|
+
ruby.exception_arg_error(),
|
|
169
161
|
format!("Label '{}' already exists in index", label)
|
|
170
162
|
));
|
|
171
163
|
}
|
|
172
|
-
|
|
164
|
+
|
|
173
165
|
let id = *current_id;
|
|
174
166
|
label_map.insert(label.clone(), id);
|
|
175
167
|
*current_id += 1;
|
|
176
168
|
id
|
|
177
169
|
};
|
|
178
|
-
|
|
179
|
-
// Store metadata
|
|
170
|
+
|
|
180
171
|
{
|
|
181
172
|
let mut metadata_store = self.metadata_store.lock().unwrap();
|
|
182
173
|
metadata_store.insert(internal_id, ItemMetadata {
|
|
@@ -184,39 +175,38 @@ impl HnswIndex {
|
|
|
184
175
|
metadata,
|
|
185
176
|
});
|
|
186
177
|
}
|
|
187
|
-
|
|
188
|
-
// Add to HNSW
|
|
178
|
+
|
|
189
179
|
{
|
|
190
180
|
let hnsw = self.hnsw.lock().unwrap();
|
|
191
181
|
hnsw.insert((&vec_data, internal_id));
|
|
192
182
|
}
|
|
193
|
-
|
|
194
|
-
Ok(
|
|
183
|
+
|
|
184
|
+
Ok(ruby.qnil().as_value())
|
|
195
185
|
}
|
|
196
|
-
|
|
197
|
-
// Add multiple items in batch
|
|
186
|
+
|
|
198
187
|
pub fn add_batch(&self, vectors: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
199
|
-
let
|
|
188
|
+
let ruby = Ruby::get().unwrap();
|
|
189
|
+
|
|
190
|
+
let parallel: bool = if let Some(v) = kwargs.delete(ruby.to_symbol("parallel"))? {
|
|
200
191
|
TryConvert::try_convert(v).unwrap_or(true)
|
|
201
192
|
} else {
|
|
202
193
|
true
|
|
203
194
|
};
|
|
204
|
-
|
|
205
|
-
let labels: Option<RArray> = if let Some(v) = kwargs.delete(
|
|
195
|
+
|
|
196
|
+
let labels: Option<RArray> = if let Some(v) = kwargs.delete(ruby.to_symbol("labels"))? {
|
|
206
197
|
TryConvert::try_convert(v).ok()
|
|
207
198
|
} else {
|
|
208
199
|
None
|
|
209
200
|
};
|
|
210
|
-
|
|
211
|
-
// Parse all vectors
|
|
201
|
+
|
|
212
202
|
let mut data_points: Vec<(Vec<f32>, usize)> = Vec::new();
|
|
213
203
|
let mut metadata_entries: Vec<(usize, ItemMetadata)> = Vec::new();
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
let
|
|
218
|
-
|
|
219
|
-
|
|
204
|
+
|
|
205
|
+
let len = vectors.len();
|
|
206
|
+
for i in 0..len {
|
|
207
|
+
let vector: RArray = vectors.entry(i as isize)?;
|
|
208
|
+
let vec_data = parse_vector(&ruby, vector, self.dim)?;
|
|
209
|
+
|
|
220
210
|
let label = if let Some(ref labels_array) = labels {
|
|
221
211
|
labels_array.entry::<String>(i as isize)?
|
|
222
212
|
} else {
|
|
@@ -225,41 +215,38 @@ impl HnswIndex {
|
|
|
225
215
|
*id += 1;
|
|
226
216
|
label
|
|
227
217
|
};
|
|
228
|
-
|
|
229
|
-
// Get internal ID
|
|
218
|
+
|
|
230
219
|
let internal_id = {
|
|
231
220
|
let mut label_map = self.label_to_id.lock().unwrap();
|
|
232
221
|
let mut current_id = self.current_id.lock().unwrap();
|
|
233
|
-
|
|
222
|
+
|
|
234
223
|
if label_map.contains_key(&label) {
|
|
235
224
|
return Err(Error::new(
|
|
236
|
-
|
|
225
|
+
ruby.exception_arg_error(),
|
|
237
226
|
format!("Label '{}' already exists in index", label)
|
|
238
227
|
));
|
|
239
228
|
}
|
|
240
|
-
|
|
229
|
+
|
|
241
230
|
let id = *current_id;
|
|
242
231
|
label_map.insert(label.clone(), id);
|
|
243
232
|
*current_id += 1;
|
|
244
233
|
id
|
|
245
234
|
};
|
|
246
|
-
|
|
235
|
+
|
|
247
236
|
data_points.push((vec_data, internal_id));
|
|
248
237
|
metadata_entries.push((internal_id, ItemMetadata {
|
|
249
238
|
label,
|
|
250
239
|
metadata: None,
|
|
251
240
|
}));
|
|
252
241
|
}
|
|
253
|
-
|
|
254
|
-
// Store metadata
|
|
242
|
+
|
|
255
243
|
{
|
|
256
244
|
let mut metadata_store = self.metadata_store.lock().unwrap();
|
|
257
245
|
for (id, metadata) in metadata_entries {
|
|
258
246
|
metadata_store.insert(id, metadata);
|
|
259
247
|
}
|
|
260
248
|
}
|
|
261
|
-
|
|
262
|
-
// Insert into HNSW
|
|
249
|
+
|
|
263
250
|
{
|
|
264
251
|
let hnsw = self.hnsw.lock().unwrap();
|
|
265
252
|
if parallel {
|
|
@@ -271,57 +258,54 @@ impl HnswIndex {
|
|
|
271
258
|
}
|
|
272
259
|
}
|
|
273
260
|
}
|
|
274
|
-
|
|
275
|
-
Ok(
|
|
261
|
+
|
|
262
|
+
Ok(ruby.qnil().as_value())
|
|
276
263
|
}
|
|
277
|
-
|
|
278
|
-
// Search for k nearest neighbors
|
|
264
|
+
|
|
279
265
|
pub fn search(&self, query: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
280
|
-
let
|
|
266
|
+
let ruby = Ruby::get().unwrap();
|
|
267
|
+
|
|
268
|
+
let k: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("k"))? {
|
|
281
269
|
TryConvert::try_convert(v).unwrap_or(10)
|
|
282
270
|
} else {
|
|
283
271
|
10
|
|
284
272
|
};
|
|
285
|
-
|
|
286
|
-
let include_distances: bool = if let Some(v) = kwargs.delete(
|
|
273
|
+
|
|
274
|
+
let include_distances: bool = if let Some(v) = kwargs.delete(ruby.to_symbol("include_distances"))? {
|
|
287
275
|
TryConvert::try_convert(v).unwrap_or(false)
|
|
288
276
|
} else {
|
|
289
277
|
false
|
|
290
278
|
};
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
// Set search ef if provided
|
|
296
|
-
if let Some(v) = kwargs.delete(Symbol::new("ef"))? {
|
|
279
|
+
|
|
280
|
+
let query_vec = parse_vector(&ruby, query, self.dim)?;
|
|
281
|
+
|
|
282
|
+
if let Some(v) = kwargs.delete(ruby.to_symbol("ef"))? {
|
|
297
283
|
if let Ok(ef) = TryConvert::try_convert(v) as Result<usize, _> {
|
|
298
284
|
let mut ef_search = self.ef_search.lock().unwrap();
|
|
299
285
|
*ef_search = ef;
|
|
300
286
|
}
|
|
301
287
|
}
|
|
302
|
-
|
|
303
|
-
// Perform search
|
|
288
|
+
|
|
304
289
|
let neighbors = {
|
|
305
290
|
let hnsw = self.hnsw.lock().unwrap();
|
|
306
291
|
let ef_search = self.ef_search.lock().unwrap();
|
|
307
292
|
hnsw.search(&query_vec, k, *ef_search)
|
|
308
293
|
};
|
|
309
|
-
|
|
310
|
-
// Convert results
|
|
294
|
+
|
|
311
295
|
let metadata_store = self.metadata_store.lock().unwrap();
|
|
312
|
-
|
|
313
|
-
let indices =
|
|
314
|
-
let distances =
|
|
315
|
-
|
|
296
|
+
|
|
297
|
+
let indices = ruby.ary_new();
|
|
298
|
+
let distances = ruby.ary_new();
|
|
299
|
+
|
|
316
300
|
for neighbor in neighbors {
|
|
317
301
|
if let Some(metadata) = metadata_store.get(&neighbor.d_id) {
|
|
318
|
-
indices.push(
|
|
319
|
-
distances.push(
|
|
302
|
+
indices.push(ruby.str_new(&metadata.label))?;
|
|
303
|
+
distances.push(ruby.float_from_f64(neighbor.distance as f64))?;
|
|
320
304
|
}
|
|
321
305
|
}
|
|
322
|
-
|
|
306
|
+
|
|
323
307
|
if include_distances {
|
|
324
|
-
let result =
|
|
308
|
+
let result = ruby.ary_new();
|
|
325
309
|
result.push(indices)?;
|
|
326
310
|
result.push(distances)?;
|
|
327
311
|
Ok(result.as_value())
|
|
@@ -329,114 +313,107 @@ impl HnswIndex {
|
|
|
329
313
|
Ok(indices.as_value())
|
|
330
314
|
}
|
|
331
315
|
}
|
|
332
|
-
|
|
333
|
-
// Search with metadata included
|
|
316
|
+
|
|
334
317
|
pub fn search_with_metadata(&self, query: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
335
|
-
let
|
|
318
|
+
let ruby = Ruby::get().unwrap();
|
|
319
|
+
|
|
320
|
+
let k: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("k"))? {
|
|
336
321
|
TryConvert::try_convert(v).unwrap_or(10)
|
|
337
322
|
} else {
|
|
338
323
|
10
|
|
339
324
|
};
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
// Perform search
|
|
325
|
+
|
|
326
|
+
let query_vec = parse_vector(&ruby, query, self.dim)?;
|
|
327
|
+
|
|
345
328
|
let neighbors = {
|
|
346
329
|
let hnsw = self.hnsw.lock().unwrap();
|
|
347
330
|
let ef_search = self.ef_search.lock().unwrap();
|
|
348
331
|
hnsw.search(&query_vec, k, *ef_search)
|
|
349
332
|
};
|
|
350
|
-
|
|
351
|
-
// Build results with metadata
|
|
333
|
+
|
|
352
334
|
let metadata_store = self.metadata_store.lock().unwrap();
|
|
353
|
-
let results =
|
|
354
|
-
|
|
335
|
+
let results = ruby.ary_new();
|
|
336
|
+
|
|
355
337
|
for neighbor in neighbors {
|
|
356
338
|
if let Some(item_metadata) = metadata_store.get(&neighbor.d_id) {
|
|
357
|
-
let result =
|
|
358
|
-
result.aset(
|
|
359
|
-
result.aset(
|
|
360
|
-
|
|
361
|
-
let meta_hash =
|
|
339
|
+
let result = ruby.hash_new();
|
|
340
|
+
result.aset(ruby.to_symbol("label"), ruby.str_new(&item_metadata.label))?;
|
|
341
|
+
result.aset(ruby.to_symbol("distance"), ruby.float_from_f64(neighbor.distance as f64))?;
|
|
342
|
+
|
|
343
|
+
let meta_hash = ruby.hash_new();
|
|
362
344
|
if let Some(ref meta) = item_metadata.metadata {
|
|
363
345
|
for (key, value) in meta {
|
|
364
|
-
meta_hash.aset(
|
|
346
|
+
meta_hash.aset(ruby.str_new(key), ruby.str_new(value))?;
|
|
365
347
|
}
|
|
366
348
|
}
|
|
367
|
-
result.aset(
|
|
368
|
-
|
|
349
|
+
result.aset(ruby.to_symbol("metadata"), meta_hash)?;
|
|
350
|
+
|
|
369
351
|
results.push(result)?;
|
|
370
352
|
}
|
|
371
353
|
}
|
|
372
|
-
|
|
354
|
+
|
|
373
355
|
Ok(results.as_value())
|
|
374
356
|
}
|
|
375
|
-
|
|
376
|
-
// Get current size of the index
|
|
357
|
+
|
|
377
358
|
pub fn size(&self) -> Result<usize, Error> {
|
|
378
359
|
let metadata_store = self.metadata_store.lock().unwrap();
|
|
379
360
|
Ok(metadata_store.len())
|
|
380
361
|
}
|
|
381
|
-
|
|
382
|
-
// Check if index is empty
|
|
362
|
+
|
|
383
363
|
pub fn empty(&self) -> Result<bool, Error> {
|
|
384
364
|
Ok(self.size()? == 0)
|
|
385
365
|
}
|
|
386
|
-
|
|
387
|
-
// Set the ef parameter for search
|
|
366
|
+
|
|
388
367
|
pub fn set_ef(&self, ef: usize) -> Result<Value, Error> {
|
|
368
|
+
let ruby = Ruby::get().unwrap();
|
|
389
369
|
let mut ef_search = self.ef_search.lock().unwrap();
|
|
390
370
|
*ef_search = ef;
|
|
391
|
-
Ok(
|
|
371
|
+
Ok(ruby.qnil().as_value())
|
|
392
372
|
}
|
|
393
|
-
|
|
394
|
-
// Get configuration
|
|
373
|
+
|
|
395
374
|
pub fn config(&self) -> Result<RHash, Error> {
|
|
396
|
-
let
|
|
397
|
-
config.
|
|
398
|
-
|
|
375
|
+
let ruby = Ruby::get().unwrap();
|
|
376
|
+
let config = ruby.hash_new();
|
|
377
|
+
config.aset(ruby.to_symbol("dim"), ruby.integer_from_i64(self.dim as i64))?;
|
|
378
|
+
|
|
399
379
|
let space_str = match self.space {
|
|
400
380
|
DistanceType::Euclidean => "euclidean",
|
|
401
381
|
DistanceType::Cosine => "cosine",
|
|
402
382
|
DistanceType::InnerProduct => "inner_product",
|
|
403
383
|
};
|
|
404
|
-
config.aset(
|
|
405
|
-
|
|
384
|
+
config.aset(ruby.to_symbol("space"), ruby.str_new(space_str))?;
|
|
385
|
+
|
|
406
386
|
let ef_search = self.ef_search.lock().unwrap();
|
|
407
|
-
config.aset(
|
|
408
|
-
config.aset(
|
|
409
|
-
|
|
387
|
+
config.aset(ruby.to_symbol("ef"), ruby.integer_from_i64(*ef_search as i64))?;
|
|
388
|
+
config.aset(ruby.to_symbol("size"), ruby.integer_from_i64(self.size()? as i64))?;
|
|
389
|
+
|
|
410
390
|
Ok(config)
|
|
411
391
|
}
|
|
412
|
-
|
|
413
|
-
// Get statistics about the index
|
|
392
|
+
|
|
414
393
|
pub fn stats(&self) -> Result<RHash, Error> {
|
|
415
|
-
let
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
stats.aset(
|
|
419
|
-
|
|
394
|
+
let ruby = Ruby::get().unwrap();
|
|
395
|
+
let stats = ruby.hash_new();
|
|
396
|
+
|
|
397
|
+
stats.aset(ruby.to_symbol("size"), ruby.integer_from_i64(self.size()? as i64))?;
|
|
398
|
+
stats.aset(ruby.to_symbol("dim"), ruby.integer_from_i64(self.dim as i64))?;
|
|
399
|
+
|
|
420
400
|
let ef_search = self.ef_search.lock().unwrap();
|
|
421
|
-
stats.aset(
|
|
422
|
-
|
|
423
|
-
// TODO: Add more statistics from HNSW structure
|
|
424
|
-
|
|
401
|
+
stats.aset(ruby.to_symbol("ef_search"), ruby.integer_from_i64(*ef_search as i64))?;
|
|
402
|
+
|
|
425
403
|
Ok(stats)
|
|
426
404
|
}
|
|
427
|
-
|
|
428
|
-
// Load index from file (class method)
|
|
405
|
+
|
|
429
406
|
pub fn load(path: RString) -> Result<Self, Error> {
|
|
407
|
+
let ruby = Ruby::get().unwrap();
|
|
430
408
|
let path_str = path.to_string()?;
|
|
431
|
-
|
|
432
|
-
// Load metadata first to get dimensions and space
|
|
409
|
+
|
|
433
410
|
let metadata_path = format!("{}.metadata", path_str);
|
|
434
411
|
let metadata_file = File::open(&metadata_path)
|
|
435
|
-
.map_err(|e| Error::new(
|
|
436
|
-
|
|
412
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to open metadata file: {}", e)))?;
|
|
413
|
+
|
|
437
414
|
let (
|
|
438
415
|
_metadata_store,
|
|
439
|
-
_label_to_id,
|
|
416
|
+
_label_to_id,
|
|
440
417
|
_current_id,
|
|
441
418
|
_dim,
|
|
442
419
|
_space_str,
|
|
@@ -445,25 +422,19 @@ impl HnswIndex {
|
|
|
445
422
|
HashMap<String, usize>,
|
|
446
423
|
usize,
|
|
447
424
|
usize,
|
|
448
|
-
String,
|
|
425
|
+
String,
|
|
449
426
|
) = bincode::deserialize_from(metadata_file)
|
|
450
|
-
.map_err(|e| Error::new(
|
|
451
|
-
|
|
452
|
-
// Load HNSW structure
|
|
427
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to load metadata: {}", e)))?;
|
|
428
|
+
|
|
453
429
|
let hnsw_dir = format!("{}_hnsw_data", path_str);
|
|
454
430
|
let hnsw_path = std::path::Path::new(&hnsw_dir);
|
|
455
|
-
|
|
456
|
-
// Create HnswIo and leak it to get 'static lifetime
|
|
457
|
-
// This is a memory leak, but necessary due to hnsw_rs lifetime constraints
|
|
458
|
-
// The memory will never be freed until the program exits
|
|
431
|
+
|
|
459
432
|
let hnswio = Box::new(HnswIo::new(hnsw_path, "hnsw"));
|
|
460
433
|
let hnswio_static: &'static mut HnswIo = Box::leak(hnswio);
|
|
461
|
-
|
|
462
|
-
// Now we can load the HNSW with 'static lifetime
|
|
434
|
+
|
|
463
435
|
let hnsw: Hnsw<'static, f32, DistL2> = hnswio_static.load_hnsw()
|
|
464
|
-
.map_err(|e| Error::new(
|
|
465
|
-
|
|
466
|
-
// Use the loaded metadata
|
|
436
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to load HNSW index: {}", e)))?;
|
|
437
|
+
|
|
467
438
|
let metadata_store = _metadata_store;
|
|
468
439
|
let label_to_id = _label_to_id;
|
|
469
440
|
let current_id = _current_id;
|
|
@@ -472,12 +443,11 @@ impl HnswIndex {
|
|
|
472
443
|
"euclidean" => DistanceType::Euclidean,
|
|
473
444
|
"cosine" => DistanceType::Cosine,
|
|
474
445
|
"inner_product" => DistanceType::InnerProduct,
|
|
475
|
-
_ => return Err(Error::new(
|
|
446
|
+
_ => return Err(Error::new(ruby.exception_runtime_error(), "Unknown distance type in saved file")),
|
|
476
447
|
};
|
|
477
|
-
|
|
478
|
-
// Use default ef_construction as ef_search
|
|
448
|
+
|
|
479
449
|
let ef_search = 200;
|
|
480
|
-
|
|
450
|
+
|
|
481
451
|
Ok(Self {
|
|
482
452
|
hnsw: Arc::new(Mutex::new(hnsw)),
|
|
483
453
|
dim,
|
|
@@ -488,30 +458,27 @@ impl HnswIndex {
|
|
|
488
458
|
ef_search: Arc::new(Mutex::new(ef_search)),
|
|
489
459
|
})
|
|
490
460
|
}
|
|
491
|
-
|
|
492
|
-
// Save index to file
|
|
461
|
+
|
|
493
462
|
pub fn save(&self, path: RString) -> Result<Value, Error> {
|
|
463
|
+
let ruby = Ruby::get().unwrap();
|
|
494
464
|
let path_str = path.to_string()?;
|
|
495
|
-
|
|
496
|
-
// Create directory for HNSW structure
|
|
465
|
+
|
|
497
466
|
let hnsw_dir = format!("{}_hnsw_data", path_str);
|
|
498
467
|
std::fs::create_dir_all(&hnsw_dir)
|
|
499
|
-
.map_err(|e| Error::new(
|
|
500
|
-
|
|
501
|
-
// Save HNSW structure
|
|
468
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to create directory: {}", e)))?;
|
|
469
|
+
|
|
502
470
|
{
|
|
503
471
|
let hnsw = self.hnsw.lock().unwrap();
|
|
504
472
|
hnsw.file_dump(&std::path::Path::new(&hnsw_dir), "hnsw")
|
|
505
|
-
.map_err(|e| Error::new(
|
|
473
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to save HNSW: {}", e)))?;
|
|
506
474
|
}
|
|
507
|
-
|
|
508
|
-
// Save metadata
|
|
475
|
+
|
|
509
476
|
let metadata_path = format!("{}.metadata", path_str);
|
|
510
477
|
{
|
|
511
478
|
let metadata_store = self.metadata_store.lock().unwrap();
|
|
512
479
|
let label_to_id = self.label_to_id.lock().unwrap();
|
|
513
480
|
let current_id = self.current_id.lock().unwrap();
|
|
514
|
-
|
|
481
|
+
|
|
515
482
|
let metadata_data = (
|
|
516
483
|
&*metadata_store,
|
|
517
484
|
&*label_to_id,
|
|
@@ -523,56 +490,55 @@ impl HnswIndex {
|
|
|
523
490
|
DistanceType::InnerProduct => "inner_product",
|
|
524
491
|
},
|
|
525
492
|
);
|
|
526
|
-
|
|
493
|
+
|
|
527
494
|
let file = File::create(&metadata_path)
|
|
528
|
-
.map_err(|e| Error::new(
|
|
529
|
-
|
|
495
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to create metadata file: {}", e)))?;
|
|
496
|
+
|
|
530
497
|
bincode::serialize_into(file, &metadata_data)
|
|
531
|
-
.map_err(|e| Error::new(
|
|
498
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to save metadata: {}", e)))?;
|
|
532
499
|
}
|
|
533
|
-
|
|
534
|
-
Ok(
|
|
500
|
+
|
|
501
|
+
Ok(ruby.qnil().as_value())
|
|
535
502
|
}
|
|
536
503
|
}
|
|
537
504
|
|
|
538
505
|
// Helper function to parse a Ruby array into a Vec<f32>
|
|
539
|
-
fn parse_vector(array: RArray, expected_dim: usize) -> Result<Vec<f32>, Error> {
|
|
506
|
+
fn parse_vector(ruby: &Ruby, array: RArray, expected_dim: usize) -> Result<Vec<f32>, Error> {
|
|
540
507
|
let len = array.len();
|
|
541
508
|
if len != expected_dim {
|
|
542
509
|
return Err(Error::new(
|
|
543
|
-
|
|
510
|
+
ruby.exception_arg_error(),
|
|
544
511
|
format!("Vector dimension mismatch: expected {}, got {}", expected_dim, len)
|
|
545
512
|
));
|
|
546
513
|
}
|
|
547
|
-
|
|
514
|
+
|
|
548
515
|
let mut vec = Vec::with_capacity(len);
|
|
549
|
-
for
|
|
550
|
-
let value: f64 =
|
|
551
|
-
.map_err(|_| Error::new(exception::type_error(), "Vector elements must be numeric"))?;
|
|
516
|
+
for i in 0..len {
|
|
517
|
+
let value: f64 = array.entry(i as isize)?;
|
|
552
518
|
vec.push(value as f32);
|
|
553
519
|
}
|
|
554
|
-
|
|
520
|
+
|
|
555
521
|
Ok(vec)
|
|
556
522
|
}
|
|
557
523
|
|
|
558
524
|
// Helper function to parse metadata
|
|
559
|
-
fn parse_metadata(value: Value) -> Result<HashMap<String, String>, Error> {
|
|
525
|
+
fn parse_metadata(ruby: &Ruby, value: Value) -> Result<HashMap<String, String>, Error> {
|
|
560
526
|
let hash: RHash = TryConvert::try_convert(value)
|
|
561
|
-
.map_err(|_| Error::new(
|
|
562
|
-
|
|
527
|
+
.map_err(|_| Error::new(ruby.exception_type_error(), "Metadata must be a hash"))?;
|
|
528
|
+
|
|
563
529
|
let mut metadata = HashMap::new();
|
|
564
|
-
|
|
530
|
+
|
|
565
531
|
hash.foreach(|key: Value, value: Value| {
|
|
566
|
-
|
|
532
|
+
let ruby = Ruby::get().unwrap();
|
|
533
|
+
|
|
567
534
|
let key_str = if let Ok(s) = String::try_convert(key) {
|
|
568
535
|
s
|
|
569
536
|
} else if let Ok(sym) = Symbol::try_convert(key) {
|
|
570
537
|
sym.name()?.to_string()
|
|
571
538
|
} else {
|
|
572
|
-
return Err(Error::new(
|
|
539
|
+
return Err(Error::new(ruby.exception_type_error(), "Metadata keys must be strings or symbols"));
|
|
573
540
|
};
|
|
574
|
-
|
|
575
|
-
// Convert value to string (handle various Ruby types)
|
|
541
|
+
|
|
576
542
|
let value_str = if let Ok(s) = String::try_convert(value) {
|
|
577
543
|
s
|
|
578
544
|
} else if let Ok(i) = Integer::try_convert(value) {
|
|
@@ -580,22 +546,22 @@ fn parse_metadata(value: Value) -> Result<HashMap<String, String>, Error> {
|
|
|
580
546
|
} else if let Ok(f) = Float::try_convert(value) {
|
|
581
547
|
f.to_f64().to_string()
|
|
582
548
|
} else {
|
|
583
|
-
// Fallback: use Ruby's to_s method
|
|
584
549
|
let to_s_method = value.funcall::<_, _, RString>("to_s", ())?;
|
|
585
550
|
to_s_method.to_string()?
|
|
586
551
|
};
|
|
587
|
-
|
|
552
|
+
|
|
588
553
|
metadata.insert(key_str, value_str);
|
|
589
554
|
Ok(ForEach::Continue)
|
|
590
555
|
})?;
|
|
591
|
-
|
|
556
|
+
|
|
592
557
|
Ok(metadata)
|
|
593
558
|
}
|
|
594
559
|
|
|
595
560
|
// Initialize the HNSW module
|
|
596
561
|
pub fn init(parent: &magnus::RModule) -> Result<(), Error> {
|
|
597
|
-
let
|
|
598
|
-
|
|
562
|
+
let ruby = Ruby::get().unwrap();
|
|
563
|
+
let class = parent.define_class("HNSW", ruby.class_object())?;
|
|
564
|
+
|
|
599
565
|
class.define_singleton_method("new", function!(HnswIndex::new, 1))?;
|
|
600
566
|
class.define_singleton_method("load", function!(HnswIndex::load, 1))?;
|
|
601
567
|
class.define_method("add_item", method!(HnswIndex::add_item, 2))?;
|
|
@@ -608,6 +574,6 @@ pub fn init(parent: &magnus::RModule) -> Result<(), Error> {
|
|
|
608
574
|
class.define_method("config", method!(HnswIndex::config, 0))?;
|
|
609
575
|
class.define_method("stats", method!(HnswIndex::stats, 0))?;
|
|
610
576
|
class.define_method("save", method!(HnswIndex::save, 1))?;
|
|
611
|
-
|
|
577
|
+
|
|
612
578
|
Ok(())
|
|
613
|
-
}
|
|
579
|
+
}
|