clusterkit 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +3236 -0
- data/README.md +227 -7
- data/docs/KNOWN_ISSUES.md +5 -5
- data/docs/RUST_ERROR_HANDLING.md +6 -6
- data/docs/assets/clusterkit-wide.png +0 -0
- data/docs/assets/clusterkit.png +0 -0
- data/docs/assets/visualization.png +0 -0
- data/ext/clusterkit/Cargo.toml +5 -4
- data/ext/clusterkit/extconf.rb +9 -1
- data/ext/clusterkit/src/clustering/hdbscan_wrapper.rs +27 -62
- data/ext/clusterkit/src/clustering.rs +68 -114
- data/ext/clusterkit/src/embedder.rs +48 -131
- data/ext/clusterkit/src/hnsw.rs +579 -0
- data/ext/clusterkit/src/lib.rs +7 -5
- data/ext/clusterkit/src/svd.rs +35 -58
- data/ext/clusterkit/src/utils.rs +159 -9
- data/lib/clusterkit/clustering/hdbscan.rb +4 -17
- data/lib/clusterkit/clustering.rb +4 -23
- data/lib/clusterkit/data_validator.rb +132 -0
- data/lib/clusterkit/dimensionality/pca.rb +12 -12
- data/lib/clusterkit/dimensionality/svd.rb +47 -16
- data/lib/clusterkit/dimensionality/umap.rb +7 -40
- data/lib/clusterkit/hnsw.rb +251 -0
- data/lib/clusterkit/version.rb +1 -1
- data/lib/clusterkit.rb +2 -1
- metadata +40 -20
- data/clusterkit.gemspec +0 -45
|
@@ -0,0 +1,579 @@
|
|
|
1
|
+
use magnus::{
|
|
2
|
+
function, method, prelude::*,
|
|
3
|
+
Error, Float, Integer, RArray, RHash, RString, Symbol, Value, TryConvert, Ruby,
|
|
4
|
+
r_hash::ForEach,
|
|
5
|
+
};
|
|
6
|
+
use hnsw_rs::prelude::*;
|
|
7
|
+
use hnsw_rs::hnswio::HnswIo;
|
|
8
|
+
use std::collections::HashMap;
|
|
9
|
+
use std::sync::{Arc, Mutex};
|
|
10
|
+
use serde::{Serialize, Deserialize};
|
|
11
|
+
use std::fs::File;
|
|
12
|
+
|
|
13
|
+
/// Per-item data stored alongside each vector, keyed by the vector's internal id.
///
/// Instances are serialized with bincode into the `<path>.metadata` file written by
/// `HnswIndex::save` and read back by `HnswIndex::load`. bincode encodes struct fields
/// positionally, so the field order below is part of the on-disk format — do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct ItemMetadata {
    /// Label returned to Ruby by searches (auto-generated when the caller omits one).
    label: String,
    /// Optional string key/value pairs supplied through the `metadata:` keyword.
    metadata: Option<HashMap<String, String>>,
}
|
|
19
|
+
|
|
20
|
+
// Main HNSW wrapper struct
|
|
21
|
+
#[magnus::wrap(class = "ClusterKit::HNSW", free_immediately, size)]
|
|
22
|
+
pub struct HnswIndex {
|
|
23
|
+
hnsw: Arc<Mutex<Hnsw<'static, f32, DistL2>>>,
|
|
24
|
+
dim: usize,
|
|
25
|
+
space: DistanceType,
|
|
26
|
+
metadata_store: Arc<Mutex<HashMap<usize, ItemMetadata>>>,
|
|
27
|
+
current_id: Arc<Mutex<usize>>,
|
|
28
|
+
label_to_id: Arc<Mutex<HashMap<String, usize>>>,
|
|
29
|
+
ef_search: Arc<Mutex<usize>>,
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
/// Distance metric an index was configured with.
///
/// `HnswIndex::new` currently accepts only `Euclidean`; `HnswIndex::load` can still
/// produce the other variants from the metric name stored in a saved file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// NOTE(review): `Cosine`/`InnerProduct` are never produced by `new`; this allow presumably
// silences unused-variant warnings from an earlier revision — confirm it is still needed.
#[allow(dead_code)]
enum DistanceType {
    /// Euclidean (L2) distance — the only metric `new` accepts.
    Euclidean,
    /// Cosine distance — rejected by `new` as not yet implemented.
    Cosine,
    /// Inner-product distance — rejected by `new` as not yet implemented.
    InnerProduct,
}
|
|
39
|
+
|
|
40
|
+
impl HnswIndex {
|
|
41
|
+
pub fn new(kwargs: RHash) -> Result<Self, Error> {
|
|
42
|
+
let ruby = Ruby::get().unwrap();
|
|
43
|
+
|
|
44
|
+
let dim_opt: Option<Value> = kwargs.delete(ruby.to_symbol("dim"))?;
|
|
45
|
+
let dim_value = dim_opt.ok_or_else(|| Error::new(ruby.exception_arg_error(), "dim is required"))?;
|
|
46
|
+
let dim: usize = TryConvert::try_convert(dim_value)
|
|
47
|
+
.map_err(|_| Error::new(ruby.exception_arg_error(), "dim must be an integer"))?;
|
|
48
|
+
|
|
49
|
+
if dim == 0 {
|
|
50
|
+
return Err(Error::new(ruby.exception_arg_error(), "dim must be a positive integer (got 0)"));
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
let space: String = if let Some(v) = kwargs.delete(ruby.to_symbol("space"))? {
|
|
54
|
+
if let Ok(sym) = Symbol::try_convert(v) {
|
|
55
|
+
sym.name()?.to_string()
|
|
56
|
+
} else if let Ok(s) = String::try_convert(v) {
|
|
57
|
+
s
|
|
58
|
+
} else {
|
|
59
|
+
return Err(Error::new(
|
|
60
|
+
ruby.exception_type_error(),
|
|
61
|
+
"space must be a string or symbol"
|
|
62
|
+
));
|
|
63
|
+
}
|
|
64
|
+
} else {
|
|
65
|
+
"euclidean".to_string()
|
|
66
|
+
};
|
|
67
|
+
|
|
68
|
+
let max_elements: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("max_elements"))? {
|
|
69
|
+
TryConvert::try_convert(v).unwrap_or(10_000)
|
|
70
|
+
} else {
|
|
71
|
+
10_000
|
|
72
|
+
};
|
|
73
|
+
|
|
74
|
+
let m: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("M"))? {
|
|
75
|
+
TryConvert::try_convert(v).unwrap_or(16)
|
|
76
|
+
} else {
|
|
77
|
+
16
|
|
78
|
+
};
|
|
79
|
+
|
|
80
|
+
let ef_construction: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("ef_construction"))? {
|
|
81
|
+
TryConvert::try_convert(v).unwrap_or(200)
|
|
82
|
+
} else {
|
|
83
|
+
200
|
|
84
|
+
};
|
|
85
|
+
|
|
86
|
+
let random_seed: Option<u64> = if let Some(v) = kwargs.delete(ruby.to_symbol("random_seed"))? {
|
|
87
|
+
TryConvert::try_convert(v).ok()
|
|
88
|
+
} else {
|
|
89
|
+
None
|
|
90
|
+
};
|
|
91
|
+
|
|
92
|
+
let distance_type = match space.as_str() {
|
|
93
|
+
"euclidean" => DistanceType::Euclidean,
|
|
94
|
+
"cosine" => {
|
|
95
|
+
return Err(Error::new(
|
|
96
|
+
ruby.exception_runtime_error(),
|
|
97
|
+
"Cosine distance is not yet implemented, please use :euclidean"
|
|
98
|
+
));
|
|
99
|
+
},
|
|
100
|
+
"inner_product" => {
|
|
101
|
+
return Err(Error::new(
|
|
102
|
+
ruby.exception_runtime_error(),
|
|
103
|
+
"Inner product distance is not yet implemented, please use :euclidean"
|
|
104
|
+
));
|
|
105
|
+
},
|
|
106
|
+
_ => return Err(Error::new(
|
|
107
|
+
ruby.exception_arg_error(),
|
|
108
|
+
format!("space must be :euclidean, :cosine, or :inner_product (got: {})", space)
|
|
109
|
+
)),
|
|
110
|
+
};
|
|
111
|
+
|
|
112
|
+
let hnsw = if let Some(seed) = random_seed {
|
|
113
|
+
Hnsw::<f32, DistL2>::new_with_seed(m, max_elements, 16, ef_construction, DistL2, seed)
|
|
114
|
+
} else {
|
|
115
|
+
Hnsw::<f32, DistL2>::new(m, max_elements, 16, ef_construction, DistL2)
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
Ok(Self {
|
|
119
|
+
hnsw: Arc::new(Mutex::new(hnsw)),
|
|
120
|
+
dim,
|
|
121
|
+
space: distance_type,
|
|
122
|
+
metadata_store: Arc::new(Mutex::new(HashMap::new())),
|
|
123
|
+
current_id: Arc::new(Mutex::new(0)),
|
|
124
|
+
label_to_id: Arc::new(Mutex::new(HashMap::new())),
|
|
125
|
+
ef_search: Arc::new(Mutex::new(ef_construction)),
|
|
126
|
+
})
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
pub fn add_item(&self, vector: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
130
|
+
let ruby = Ruby::get().unwrap();
|
|
131
|
+
|
|
132
|
+
let vec_data = parse_vector(&ruby, vector, self.dim)?;
|
|
133
|
+
|
|
134
|
+
let label: String = if let Some(v) = kwargs.delete(ruby.to_symbol("label"))? {
|
|
135
|
+
TryConvert::try_convert(v).unwrap_or_else(|_| {
|
|
136
|
+
let mut id = self.current_id.lock().unwrap();
|
|
137
|
+
let label = id.to_string();
|
|
138
|
+
*id += 1;
|
|
139
|
+
label
|
|
140
|
+
})
|
|
141
|
+
} else {
|
|
142
|
+
let mut id = self.current_id.lock().unwrap();
|
|
143
|
+
let label = id.to_string();
|
|
144
|
+
*id += 1;
|
|
145
|
+
label
|
|
146
|
+
};
|
|
147
|
+
|
|
148
|
+
let metadata: Option<HashMap<String, String>> = if let Some(v) = kwargs.delete(ruby.to_symbol("metadata"))? {
|
|
149
|
+
Some(parse_metadata(&ruby, v)?)
|
|
150
|
+
} else {
|
|
151
|
+
None
|
|
152
|
+
};
|
|
153
|
+
|
|
154
|
+
let internal_id = {
|
|
155
|
+
let mut label_map = self.label_to_id.lock().unwrap();
|
|
156
|
+
let mut current_id = self.current_id.lock().unwrap();
|
|
157
|
+
|
|
158
|
+
if label_map.contains_key(&label) {
|
|
159
|
+
return Err(Error::new(
|
|
160
|
+
ruby.exception_arg_error(),
|
|
161
|
+
format!("Label '{}' already exists in index", label)
|
|
162
|
+
));
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
let id = *current_id;
|
|
166
|
+
label_map.insert(label.clone(), id);
|
|
167
|
+
*current_id += 1;
|
|
168
|
+
id
|
|
169
|
+
};
|
|
170
|
+
|
|
171
|
+
{
|
|
172
|
+
let mut metadata_store = self.metadata_store.lock().unwrap();
|
|
173
|
+
metadata_store.insert(internal_id, ItemMetadata {
|
|
174
|
+
label: label.clone(),
|
|
175
|
+
metadata,
|
|
176
|
+
});
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
{
|
|
180
|
+
let hnsw = self.hnsw.lock().unwrap();
|
|
181
|
+
hnsw.insert((&vec_data, internal_id));
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
Ok(ruby.qnil().as_value())
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
pub fn add_batch(&self, vectors: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
188
|
+
let ruby = Ruby::get().unwrap();
|
|
189
|
+
|
|
190
|
+
let parallel: bool = if let Some(v) = kwargs.delete(ruby.to_symbol("parallel"))? {
|
|
191
|
+
TryConvert::try_convert(v).unwrap_or(true)
|
|
192
|
+
} else {
|
|
193
|
+
true
|
|
194
|
+
};
|
|
195
|
+
|
|
196
|
+
let labels: Option<RArray> = if let Some(v) = kwargs.delete(ruby.to_symbol("labels"))? {
|
|
197
|
+
TryConvert::try_convert(v).ok()
|
|
198
|
+
} else {
|
|
199
|
+
None
|
|
200
|
+
};
|
|
201
|
+
|
|
202
|
+
let mut data_points: Vec<(Vec<f32>, usize)> = Vec::new();
|
|
203
|
+
let mut metadata_entries: Vec<(usize, ItemMetadata)> = Vec::new();
|
|
204
|
+
|
|
205
|
+
let len = vectors.len();
|
|
206
|
+
for i in 0..len {
|
|
207
|
+
let vector: RArray = vectors.entry(i as isize)?;
|
|
208
|
+
let vec_data = parse_vector(&ruby, vector, self.dim)?;
|
|
209
|
+
|
|
210
|
+
let label = if let Some(ref labels_array) = labels {
|
|
211
|
+
labels_array.entry::<String>(i as isize)?
|
|
212
|
+
} else {
|
|
213
|
+
let mut id = self.current_id.lock().unwrap();
|
|
214
|
+
let label = id.to_string();
|
|
215
|
+
*id += 1;
|
|
216
|
+
label
|
|
217
|
+
};
|
|
218
|
+
|
|
219
|
+
let internal_id = {
|
|
220
|
+
let mut label_map = self.label_to_id.lock().unwrap();
|
|
221
|
+
let mut current_id = self.current_id.lock().unwrap();
|
|
222
|
+
|
|
223
|
+
if label_map.contains_key(&label) {
|
|
224
|
+
return Err(Error::new(
|
|
225
|
+
ruby.exception_arg_error(),
|
|
226
|
+
format!("Label '{}' already exists in index", label)
|
|
227
|
+
));
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
let id = *current_id;
|
|
231
|
+
label_map.insert(label.clone(), id);
|
|
232
|
+
*current_id += 1;
|
|
233
|
+
id
|
|
234
|
+
};
|
|
235
|
+
|
|
236
|
+
data_points.push((vec_data, internal_id));
|
|
237
|
+
metadata_entries.push((internal_id, ItemMetadata {
|
|
238
|
+
label,
|
|
239
|
+
metadata: None,
|
|
240
|
+
}));
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
{
|
|
244
|
+
let mut metadata_store = self.metadata_store.lock().unwrap();
|
|
245
|
+
for (id, metadata) in metadata_entries {
|
|
246
|
+
metadata_store.insert(id, metadata);
|
|
247
|
+
}
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
{
|
|
251
|
+
let hnsw = self.hnsw.lock().unwrap();
|
|
252
|
+
if parallel {
|
|
253
|
+
let data_refs: Vec<(&Vec<f32>, usize)> = data_points.iter().map(|(v, id)| (v, *id)).collect();
|
|
254
|
+
hnsw.parallel_insert(&data_refs);
|
|
255
|
+
} else {
|
|
256
|
+
for (vec, id) in data_points {
|
|
257
|
+
hnsw.insert((&vec, id));
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
Ok(ruby.qnil().as_value())
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
pub fn search(&self, query: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
266
|
+
let ruby = Ruby::get().unwrap();
|
|
267
|
+
|
|
268
|
+
let k: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("k"))? {
|
|
269
|
+
TryConvert::try_convert(v).unwrap_or(10)
|
|
270
|
+
} else {
|
|
271
|
+
10
|
|
272
|
+
};
|
|
273
|
+
|
|
274
|
+
let include_distances: bool = if let Some(v) = kwargs.delete(ruby.to_symbol("include_distances"))? {
|
|
275
|
+
TryConvert::try_convert(v).unwrap_or(false)
|
|
276
|
+
} else {
|
|
277
|
+
false
|
|
278
|
+
};
|
|
279
|
+
|
|
280
|
+
let query_vec = parse_vector(&ruby, query, self.dim)?;
|
|
281
|
+
|
|
282
|
+
if let Some(v) = kwargs.delete(ruby.to_symbol("ef"))? {
|
|
283
|
+
if let Ok(ef) = TryConvert::try_convert(v) as Result<usize, _> {
|
|
284
|
+
let mut ef_search = self.ef_search.lock().unwrap();
|
|
285
|
+
*ef_search = ef;
|
|
286
|
+
}
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
let neighbors = {
|
|
290
|
+
let hnsw = self.hnsw.lock().unwrap();
|
|
291
|
+
let ef_search = self.ef_search.lock().unwrap();
|
|
292
|
+
hnsw.search(&query_vec, k, *ef_search)
|
|
293
|
+
};
|
|
294
|
+
|
|
295
|
+
let metadata_store = self.metadata_store.lock().unwrap();
|
|
296
|
+
|
|
297
|
+
let indices = ruby.ary_new();
|
|
298
|
+
let distances = ruby.ary_new();
|
|
299
|
+
|
|
300
|
+
for neighbor in neighbors {
|
|
301
|
+
if let Some(metadata) = metadata_store.get(&neighbor.d_id) {
|
|
302
|
+
indices.push(ruby.str_new(&metadata.label))?;
|
|
303
|
+
distances.push(ruby.float_from_f64(neighbor.distance as f64))?;
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
if include_distances {
|
|
308
|
+
let result = ruby.ary_new();
|
|
309
|
+
result.push(indices)?;
|
|
310
|
+
result.push(distances)?;
|
|
311
|
+
Ok(result.as_value())
|
|
312
|
+
} else {
|
|
313
|
+
Ok(indices.as_value())
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
pub fn search_with_metadata(&self, query: RArray, kwargs: RHash) -> Result<Value, Error> {
|
|
318
|
+
let ruby = Ruby::get().unwrap();
|
|
319
|
+
|
|
320
|
+
let k: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("k"))? {
|
|
321
|
+
TryConvert::try_convert(v).unwrap_or(10)
|
|
322
|
+
} else {
|
|
323
|
+
10
|
|
324
|
+
};
|
|
325
|
+
|
|
326
|
+
let query_vec = parse_vector(&ruby, query, self.dim)?;
|
|
327
|
+
|
|
328
|
+
let neighbors = {
|
|
329
|
+
let hnsw = self.hnsw.lock().unwrap();
|
|
330
|
+
let ef_search = self.ef_search.lock().unwrap();
|
|
331
|
+
hnsw.search(&query_vec, k, *ef_search)
|
|
332
|
+
};
|
|
333
|
+
|
|
334
|
+
let metadata_store = self.metadata_store.lock().unwrap();
|
|
335
|
+
let results = ruby.ary_new();
|
|
336
|
+
|
|
337
|
+
for neighbor in neighbors {
|
|
338
|
+
if let Some(item_metadata) = metadata_store.get(&neighbor.d_id) {
|
|
339
|
+
let result = ruby.hash_new();
|
|
340
|
+
result.aset(ruby.to_symbol("label"), ruby.str_new(&item_metadata.label))?;
|
|
341
|
+
result.aset(ruby.to_symbol("distance"), ruby.float_from_f64(neighbor.distance as f64))?;
|
|
342
|
+
|
|
343
|
+
let meta_hash = ruby.hash_new();
|
|
344
|
+
if let Some(ref meta) = item_metadata.metadata {
|
|
345
|
+
for (key, value) in meta {
|
|
346
|
+
meta_hash.aset(ruby.str_new(key), ruby.str_new(value))?;
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
result.aset(ruby.to_symbol("metadata"), meta_hash)?;
|
|
350
|
+
|
|
351
|
+
results.push(result)?;
|
|
352
|
+
}
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
Ok(results.as_value())
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
pub fn size(&self) -> Result<usize, Error> {
|
|
359
|
+
let metadata_store = self.metadata_store.lock().unwrap();
|
|
360
|
+
Ok(metadata_store.len())
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
pub fn empty(&self) -> Result<bool, Error> {
|
|
364
|
+
Ok(self.size()? == 0)
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
pub fn set_ef(&self, ef: usize) -> Result<Value, Error> {
|
|
368
|
+
let ruby = Ruby::get().unwrap();
|
|
369
|
+
let mut ef_search = self.ef_search.lock().unwrap();
|
|
370
|
+
*ef_search = ef;
|
|
371
|
+
Ok(ruby.qnil().as_value())
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
pub fn config(&self) -> Result<RHash, Error> {
|
|
375
|
+
let ruby = Ruby::get().unwrap();
|
|
376
|
+
let config = ruby.hash_new();
|
|
377
|
+
config.aset(ruby.to_symbol("dim"), ruby.integer_from_i64(self.dim as i64))?;
|
|
378
|
+
|
|
379
|
+
let space_str = match self.space {
|
|
380
|
+
DistanceType::Euclidean => "euclidean",
|
|
381
|
+
DistanceType::Cosine => "cosine",
|
|
382
|
+
DistanceType::InnerProduct => "inner_product",
|
|
383
|
+
};
|
|
384
|
+
config.aset(ruby.to_symbol("space"), ruby.str_new(space_str))?;
|
|
385
|
+
|
|
386
|
+
let ef_search = self.ef_search.lock().unwrap();
|
|
387
|
+
config.aset(ruby.to_symbol("ef"), ruby.integer_from_i64(*ef_search as i64))?;
|
|
388
|
+
config.aset(ruby.to_symbol("size"), ruby.integer_from_i64(self.size()? as i64))?;
|
|
389
|
+
|
|
390
|
+
Ok(config)
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
pub fn stats(&self) -> Result<RHash, Error> {
|
|
394
|
+
let ruby = Ruby::get().unwrap();
|
|
395
|
+
let stats = ruby.hash_new();
|
|
396
|
+
|
|
397
|
+
stats.aset(ruby.to_symbol("size"), ruby.integer_from_i64(self.size()? as i64))?;
|
|
398
|
+
stats.aset(ruby.to_symbol("dim"), ruby.integer_from_i64(self.dim as i64))?;
|
|
399
|
+
|
|
400
|
+
let ef_search = self.ef_search.lock().unwrap();
|
|
401
|
+
stats.aset(ruby.to_symbol("ef_search"), ruby.integer_from_i64(*ef_search as i64))?;
|
|
402
|
+
|
|
403
|
+
Ok(stats)
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
pub fn load(path: RString) -> Result<Self, Error> {
|
|
407
|
+
let ruby = Ruby::get().unwrap();
|
|
408
|
+
let path_str = path.to_string()?;
|
|
409
|
+
|
|
410
|
+
let metadata_path = format!("{}.metadata", path_str);
|
|
411
|
+
let metadata_file = File::open(&metadata_path)
|
|
412
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to open metadata file: {}", e)))?;
|
|
413
|
+
|
|
414
|
+
let (
|
|
415
|
+
_metadata_store,
|
|
416
|
+
_label_to_id,
|
|
417
|
+
_current_id,
|
|
418
|
+
_dim,
|
|
419
|
+
_space_str,
|
|
420
|
+
): (
|
|
421
|
+
HashMap<usize, ItemMetadata>,
|
|
422
|
+
HashMap<String, usize>,
|
|
423
|
+
usize,
|
|
424
|
+
usize,
|
|
425
|
+
String,
|
|
426
|
+
) = bincode::deserialize_from(metadata_file)
|
|
427
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to load metadata: {}", e)))?;
|
|
428
|
+
|
|
429
|
+
let hnsw_dir = format!("{}_hnsw_data", path_str);
|
|
430
|
+
let hnsw_path = std::path::Path::new(&hnsw_dir);
|
|
431
|
+
|
|
432
|
+
let hnswio = Box::new(HnswIo::new(hnsw_path, "hnsw"));
|
|
433
|
+
let hnswio_static: &'static mut HnswIo = Box::leak(hnswio);
|
|
434
|
+
|
|
435
|
+
let hnsw: Hnsw<'static, f32, DistL2> = hnswio_static.load_hnsw()
|
|
436
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to load HNSW index: {}", e)))?;
|
|
437
|
+
|
|
438
|
+
let metadata_store = _metadata_store;
|
|
439
|
+
let label_to_id = _label_to_id;
|
|
440
|
+
let current_id = _current_id;
|
|
441
|
+
let dim = _dim;
|
|
442
|
+
let space = match _space_str.as_str() {
|
|
443
|
+
"euclidean" => DistanceType::Euclidean,
|
|
444
|
+
"cosine" => DistanceType::Cosine,
|
|
445
|
+
"inner_product" => DistanceType::InnerProduct,
|
|
446
|
+
_ => return Err(Error::new(ruby.exception_runtime_error(), "Unknown distance type in saved file")),
|
|
447
|
+
};
|
|
448
|
+
|
|
449
|
+
let ef_search = 200;
|
|
450
|
+
|
|
451
|
+
Ok(Self {
|
|
452
|
+
hnsw: Arc::new(Mutex::new(hnsw)),
|
|
453
|
+
dim,
|
|
454
|
+
space,
|
|
455
|
+
metadata_store: Arc::new(Mutex::new(metadata_store)),
|
|
456
|
+
current_id: Arc::new(Mutex::new(current_id)),
|
|
457
|
+
label_to_id: Arc::new(Mutex::new(label_to_id)),
|
|
458
|
+
ef_search: Arc::new(Mutex::new(ef_search)),
|
|
459
|
+
})
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
pub fn save(&self, path: RString) -> Result<Value, Error> {
|
|
463
|
+
let ruby = Ruby::get().unwrap();
|
|
464
|
+
let path_str = path.to_string()?;
|
|
465
|
+
|
|
466
|
+
let hnsw_dir = format!("{}_hnsw_data", path_str);
|
|
467
|
+
std::fs::create_dir_all(&hnsw_dir)
|
|
468
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to create directory: {}", e)))?;
|
|
469
|
+
|
|
470
|
+
{
|
|
471
|
+
let hnsw = self.hnsw.lock().unwrap();
|
|
472
|
+
hnsw.file_dump(&std::path::Path::new(&hnsw_dir), "hnsw")
|
|
473
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to save HNSW: {}", e)))?;
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
let metadata_path = format!("{}.metadata", path_str);
|
|
477
|
+
{
|
|
478
|
+
let metadata_store = self.metadata_store.lock().unwrap();
|
|
479
|
+
let label_to_id = self.label_to_id.lock().unwrap();
|
|
480
|
+
let current_id = self.current_id.lock().unwrap();
|
|
481
|
+
|
|
482
|
+
let metadata_data = (
|
|
483
|
+
&*metadata_store,
|
|
484
|
+
&*label_to_id,
|
|
485
|
+
*current_id,
|
|
486
|
+
self.dim,
|
|
487
|
+
match self.space {
|
|
488
|
+
DistanceType::Euclidean => "euclidean",
|
|
489
|
+
DistanceType::Cosine => "cosine",
|
|
490
|
+
DistanceType::InnerProduct => "inner_product",
|
|
491
|
+
},
|
|
492
|
+
);
|
|
493
|
+
|
|
494
|
+
let file = File::create(&metadata_path)
|
|
495
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to create metadata file: {}", e)))?;
|
|
496
|
+
|
|
497
|
+
bincode::serialize_into(file, &metadata_data)
|
|
498
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to save metadata: {}", e)))?;
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
Ok(ruby.qnil().as_value())
|
|
502
|
+
}
|
|
503
|
+
}
|
|
504
|
+
|
|
505
|
+
// Helper function to parse a Ruby array into a Vec<f32>
|
|
506
|
+
fn parse_vector(ruby: &Ruby, array: RArray, expected_dim: usize) -> Result<Vec<f32>, Error> {
|
|
507
|
+
let len = array.len();
|
|
508
|
+
if len != expected_dim {
|
|
509
|
+
return Err(Error::new(
|
|
510
|
+
ruby.exception_arg_error(),
|
|
511
|
+
format!("Vector dimension mismatch: expected {}, got {}", expected_dim, len)
|
|
512
|
+
));
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
let mut vec = Vec::with_capacity(len);
|
|
516
|
+
for i in 0..len {
|
|
517
|
+
let value: f64 = array.entry(i as isize)?;
|
|
518
|
+
vec.push(value as f32);
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
Ok(vec)
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
// Helper function to parse metadata
|
|
525
|
+
fn parse_metadata(ruby: &Ruby, value: Value) -> Result<HashMap<String, String>, Error> {
|
|
526
|
+
let hash: RHash = TryConvert::try_convert(value)
|
|
527
|
+
.map_err(|_| Error::new(ruby.exception_type_error(), "Metadata must be a hash"))?;
|
|
528
|
+
|
|
529
|
+
let mut metadata = HashMap::new();
|
|
530
|
+
|
|
531
|
+
hash.foreach(|key: Value, value: Value| {
|
|
532
|
+
let ruby = Ruby::get().unwrap();
|
|
533
|
+
|
|
534
|
+
let key_str = if let Ok(s) = String::try_convert(key) {
|
|
535
|
+
s
|
|
536
|
+
} else if let Ok(sym) = Symbol::try_convert(key) {
|
|
537
|
+
sym.name()?.to_string()
|
|
538
|
+
} else {
|
|
539
|
+
return Err(Error::new(ruby.exception_type_error(), "Metadata keys must be strings or symbols"));
|
|
540
|
+
};
|
|
541
|
+
|
|
542
|
+
let value_str = if let Ok(s) = String::try_convert(value) {
|
|
543
|
+
s
|
|
544
|
+
} else if let Ok(i) = Integer::try_convert(value) {
|
|
545
|
+
i.to_string()
|
|
546
|
+
} else if let Ok(f) = Float::try_convert(value) {
|
|
547
|
+
f.to_f64().to_string()
|
|
548
|
+
} else {
|
|
549
|
+
let to_s_method = value.funcall::<_, _, RString>("to_s", ())?;
|
|
550
|
+
to_s_method.to_string()?
|
|
551
|
+
};
|
|
552
|
+
|
|
553
|
+
metadata.insert(key_str, value_str);
|
|
554
|
+
Ok(ForEach::Continue)
|
|
555
|
+
})?;
|
|
556
|
+
|
|
557
|
+
Ok(metadata)
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
// Initialize the HNSW module
|
|
561
|
+
pub fn init(parent: &magnus::RModule) -> Result<(), Error> {
|
|
562
|
+
let ruby = Ruby::get().unwrap();
|
|
563
|
+
let class = parent.define_class("HNSW", ruby.class_object())?;
|
|
564
|
+
|
|
565
|
+
class.define_singleton_method("new", function!(HnswIndex::new, 1))?;
|
|
566
|
+
class.define_singleton_method("load", function!(HnswIndex::load, 1))?;
|
|
567
|
+
class.define_method("add_item", method!(HnswIndex::add_item, 2))?;
|
|
568
|
+
class.define_method("add_batch", method!(HnswIndex::add_batch, 2))?;
|
|
569
|
+
class.define_method("search", method!(HnswIndex::search, 2))?;
|
|
570
|
+
class.define_method("search_with_metadata", method!(HnswIndex::search_with_metadata, 2))?;
|
|
571
|
+
class.define_method("size", method!(HnswIndex::size, 0))?;
|
|
572
|
+
class.define_method("empty?", method!(HnswIndex::empty, 0))?;
|
|
573
|
+
class.define_method("set_ef", method!(HnswIndex::set_ef, 1))?;
|
|
574
|
+
class.define_method("config", method!(HnswIndex::config, 0))?;
|
|
575
|
+
class.define_method("stats", method!(HnswIndex::stats, 0))?;
|
|
576
|
+
class.define_method("save", method!(HnswIndex::save, 1))?;
|
|
577
|
+
|
|
578
|
+
Ok(())
|
|
579
|
+
}
|
data/ext/clusterkit/src/lib.rs
CHANGED
|
@@ -1,22 +1,24 @@
|
|
|
1
|
-
use magnus::{
|
|
1
|
+
use magnus::{Error, Ruby};
|
|
2
2
|
|
|
3
3
|
mod embedder;
|
|
4
4
|
mod svd;
|
|
5
5
|
mod utils;
|
|
6
6
|
mod clustering;
|
|
7
|
+
mod hnsw;
|
|
7
8
|
|
|
8
9
|
#[cfg(test)]
|
|
9
10
|
mod tests;
|
|
10
11
|
|
|
11
12
|
#[magnus::init]
|
|
12
|
-
fn init() -> Result<(), Error> {
|
|
13
|
-
let module = define_module("ClusterKit")?;
|
|
14
|
-
|
|
13
|
+
fn init(ruby: &Ruby) -> Result<(), Error> {
|
|
14
|
+
let module = ruby.define_module("ClusterKit")?;
|
|
15
|
+
|
|
15
16
|
// Initialize submodules
|
|
16
17
|
embedder::init(&module)?;
|
|
17
18
|
svd::init(&module)?;
|
|
18
19
|
utils::init(&module)?;
|
|
19
20
|
clustering::init(&module)?;
|
|
20
|
-
|
|
21
|
+
hnsw::init(&module)?;
|
|
22
|
+
|
|
21
23
|
Ok(())
|
|
22
24
|
}
|