hnswlib 0.7.0 → 0.8.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
4
- data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
3
+ metadata.gz: 33dd2f9cd8656dfe469b4b9baf319ae4e0b78826dec58da8ae78516077a6fbc2
4
+ data.tar.gz: ad32e54a325136587ae12716243b88ed1269c7bc9cd4e92e9aa8896b2517d4d2
5
5
  SHA512:
6
- metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
7
- data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
6
+ metadata.gz: 361822ebf216d8f8ba3abf8314e52b1a72887fb2a03aea29572940376326b128442fc007fa7de6a6e925facb0b9295076fb10a98ff053aad7e0fbe41e4973c2a
7
+ data.tar.gz: 38f9f4b10665d9e87940b23fae51a15640229367f6bb74a7dba176bab6e798c8fad8abce106c07e38e0acc9046949f80208f9f9b30e3341aceb5a6193dac3aca
data/CHANGELOG.md CHANGED
@@ -1,3 +1,29 @@
1
+ ## [0.8.1] - 2023-03-18
2
+
3
+ - Update the type declarations of HierarchicalNSW and BruteforceSearch to reflect recent changes.
4
+
5
+ ## [0.8.0] - 2023-03-14
6
+
7
+ **Breaking change:**
8
+
9
+ - Change the `space` argument of the `initialize` method to take a String
10
+ in [HierarchicalNSW](https://yoshoku.github.io/hnswlib.rb/doc/Hnswlib/HierarchicalNSW.html) and [BruteforceSearch](https://yoshoku.github.io/hnswlib.rb/doc/Hnswlib/BruteforceSearch.html).
11
+ - Add `init_index` method to HierarchicalNSW and BruteforceSearch.
12
+ Along with this, some arguments of `initialize` method moved to `init_index` method.
13
+ ```ruby
14
+ require 'hnswlib'
15
+
16
+ n_features = 3
17
+ max_elements = 10
18
+
19
+ hnsw = Hnswlib::HierarchicalNSW.new(space: 'l2', dim: n_features)
20
+ hnsw.init_index(max_elements: max_elements, m: 16, ef_construction: 200, random_seed: 42, allow_replace_deleted: false)
21
+
22
+ bf = Hnswlib::BruteforceSearch.new(space: 'l2', dim: n_features)
23
+ bf.init_index(max_elements: max_elements)
24
+ ```
25
+ - Deprecate [HnswIndex](https://yoshoku.github.io/hnswlib.rb/doc/Hnswlib/HnswIndex.html), which has an interface similar to Annoy.
26
+
1
27
  ## [0.7.0] - 2023-03-04
2
28
 
3
29
  - Update bundled hnswlib version to 0.7.0.
data/README.md CHANGED
@@ -48,19 +48,26 @@ $ gem install hnswlib -- --with-cxxflags=-march=native
48
48
  ```ruby
49
49
  require 'hnswlib'
50
50
 
51
- f = 40 # length of item vector that will be indexed.
52
- t = Hnswlib::HnswIndex.new(n_features: f, max_item: 1000)
51
+ f = 40 # length of datum point vector that will be indexed.
52
+ t = Hnswlib::HierarchicalNSW.new(space: 'l2', dim: f)
53
+ t.init_index(max_elements: 1000)
53
54
 
54
55
  1000.times do |i|
55
56
  v = Array.new(f) { rand }
56
- t.add_item(i, v)
57
+ t.add_point(v, i)
57
58
  end
58
59
 
59
- t.save('test.ann')
60
+ t.save_index('test.ann')
61
+ ```
62
+
63
+ ```ruby
64
+ require 'hnswlib'
65
+
66
+ u = Hnswlib::HierarchicalNSW.new(space: 'l2', dim: f)
67
+ u.load_index('test.ann')
60
68
 
61
- u = Hnswlib::HnswIndex.new(n_features: f, max_item: 1000)
62
- u.load('test.ann')
63
- p u.get_nns_by_item(0, 100) # will find the 100 nearest neighbors.
69
+ q = Array.new(f) { rand }
70
+ p u.search_knn(q, 100) # will find the 100 nearest neighbors.
64
71
  ```
65
72
 
66
73
  ## License
@@ -18,8 +18,6 @@
18
18
 
19
19
  #include "hnswlibext.hpp"
20
20
 
21
- VALUE rb_mHnswlib;
22
-
23
21
  extern "C" void Init_hnswlibext(void) {
24
22
  rb_mHnswlib = rb_define_module("Hnswlib");
25
23
  RbHnswlibL2Space::define_class(rb_mHnswlib);
@@ -23,9 +23,11 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <cmath>
26
27
  #include <new>
27
28
  #include <vector>
28
29
 
30
+ VALUE rb_mHnswlib;
29
31
  VALUE rb_cHnswlibL2Space;
30
32
  VALUE rb_cHnswlibInnerProductSpace;
31
33
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -52,8 +54,8 @@ public:
52
54
  return ptr;
53
55
  };
54
56
 
55
- static VALUE define_class(VALUE rb_mHnswlib) {
56
- rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
57
+ static VALUE define_class(VALUE outer) {
58
+ rb_cHnswlibL2Space = rb_define_class_under(outer, "L2Space", rb_cObject);
57
59
  rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
58
60
  rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
59
61
  rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
@@ -128,8 +130,8 @@ public:
128
130
  return ptr;
129
131
  };
130
132
 
131
- static VALUE define_class(VALUE rb_mHnswlib) {
132
- rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
133
+ static VALUE define_class(VALUE outer) {
134
+ rb_cHnswlibInnerProductSpace = rb_define_class_under(outer, "InnerProductSpace", rb_cObject);
133
135
  rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
134
136
  rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
135
137
  rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
@@ -218,10 +220,11 @@ public:
218
220
  return ptr;
219
221
  };
220
222
 
221
- static VALUE define_class(VALUE rb_mHnswlib) {
222
- rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
223
+ static VALUE define_class(VALUE outer) {
224
+ rb_cHnswlibHierarchicalNSW = rb_define_class_under(outer, "HierarchicalNSW", rb_cObject);
223
225
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
224
- rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_initialize), -1);
227
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "init_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init_index), -1);
225
228
  rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
226
229
  rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
227
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
@@ -244,59 +247,90 @@ public:
244
247
  private:
245
248
  static const rb_data_type_t hnsw_hierarchicalnsw_type;
246
249
 
247
- static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
250
+ static VALUE _hnsw_hierarchicalnsw_initialize(int argc, VALUE* argv, VALUE self) {
248
251
  VALUE kw_args = Qnil;
249
- ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
250
- rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
251
- VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
252
+ ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
253
+ VALUE kw_values[2] = {Qundef, Qundef};
252
254
  rb_scan_args(argc, argv, ":", &kw_args);
253
- rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
254
- if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
255
- if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
256
- if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
257
- if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
255
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
258
256
 
259
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
260
- rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
261
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
257
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
258
+ rb_raise(rb_eTypeError, "expected space, String");
259
+ return Qnil;
260
+ }
261
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
262
+ strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
263
+ rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
262
264
  return Qnil;
263
265
  }
264
266
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
267
+ rb_raise(rb_eTypeError, "expected dim, Integer");
268
+ return Qnil;
269
+ }
270
+
271
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
272
+ rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
273
+ } else {
274
+ rb_iv_set(self, "@space",
275
+ rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
276
+ }
277
+
278
+ rb_iv_set(self, "@normalize", Qfalse);
279
+ if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
280
+
281
+ return Qnil;
282
+ };
283
+
284
+ static VALUE _hnsw_hierarchicalnsw_init_index(int argc, VALUE* argv, VALUE self) {
285
+ VALUE kw_args = Qnil;
286
+ ID kw_table[5] = {rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"), rb_intern("random_seed"),
287
+ rb_intern("allow_replace_deleted")};
288
+ VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
289
+ rb_scan_args(argc, argv, ":", &kw_args);
290
+ rb_get_kwargs(kw_args, kw_table, 1, 4, kw_values);
291
+ if (kw_values[1] == Qundef) kw_values[1] = SIZET2NUM(16);
292
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(200);
293
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(100);
294
+ if (kw_values[4] == Qundef) kw_values[4] = Qfalse;
295
+
296
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
265
297
  rb_raise(rb_eTypeError, "expected max_elements, Integer");
266
298
  return Qnil;
267
299
  }
268
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
300
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
269
301
  rb_raise(rb_eTypeError, "expected m, Integer");
270
302
  return Qnil;
271
303
  }
272
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
304
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
273
305
  rb_raise(rb_eTypeError, "expected ef_construction, Integer");
274
306
  return Qnil;
275
307
  }
276
- if (!RB_INTEGER_TYPE_P(kw_values[4])) {
308
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
277
309
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
278
310
  return Qnil;
279
311
  }
280
- if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
312
+ if (!RB_TYPE_P(kw_values[4], T_TRUE) && !RB_TYPE_P(kw_values[4], T_FALSE)) {
281
313
  rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
282
314
  return Qnil;
283
315
  }
284
316
 
285
- rb_iv_set(self, "@space", kw_values[0]);
286
- hnswlib::SpaceInterface<float>* space;
287
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
288
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
317
+ hnswlib::SpaceInterface<float>* space = nullptr;
318
+ VALUE ivspace = rb_iv_get(self, "@space");
319
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
320
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
289
321
  } else {
290
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
322
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
291
323
  }
292
- const size_t max_elements = NUM2SIZET(kw_values[1]);
293
- const size_t m = NUM2SIZET(kw_values[2]);
294
- const size_t ef_construction = NUM2SIZET(kw_values[3]);
295
- const size_t random_seed = NUM2SIZET(kw_values[4]);
296
- const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
324
+
325
+ const size_t max_elements = NUM2SIZET(kw_values[0]);
326
+ const size_t m = NUM2SIZET(kw_values[1]);
327
+ const size_t ef_construction = NUM2SIZET(kw_values[2]);
328
+ const size_t random_seed = NUM2SIZET(kw_values[3]);
329
+ const bool allow_replace_deleted = kw_values[4] == Qtrue ? true : false;
297
330
 
298
331
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
299
332
  try {
333
+ ptr->~HierarchicalNSW();
300
334
  new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
301
335
  } catch (const std::runtime_error& e) {
302
336
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -335,20 +369,29 @@ private:
335
369
  return Qfalse;
336
370
  }
337
371
 
338
- float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
339
- for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
372
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
373
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
340
374
  const size_t idx = NUM2SIZET(_idx);
341
375
  const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
342
376
 
377
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
378
+ float norm = 0.0;
379
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
380
+ norm = std::sqrt(std::fabs(norm));
381
+ if (norm >= 0.0) {
382
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
383
+ }
384
+ }
385
+
343
386
  try {
344
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
387
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, idx, replace_deleted);
345
388
  } catch (const std::runtime_error& e) {
346
- ruby_xfree(arr);
389
+ ruby_xfree(vec);
347
390
  rb_raise(rb_eRuntimeError, "%s", e.what());
348
391
  return Qfalse;
349
392
  }
350
393
 
351
- ruby_xfree(arr);
394
+ ruby_xfree(vec);
352
395
  return Qtrue;
353
396
  };
354
397
 
@@ -392,6 +435,15 @@ private:
392
435
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
393
436
  }
394
437
 
438
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
439
+ float norm = 0.0;
440
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
441
+ norm = std::sqrt(std::fabs(norm));
442
+ if (norm >= 0.0) {
443
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
444
+ }
445
+ }
446
+
395
447
  std::priority_queue<std::pair<float, size_t>> result;
396
448
  try {
397
449
  result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
@@ -607,10 +659,11 @@ public:
607
659
  return ptr;
608
660
  };
609
661
 
610
- static VALUE define_class(VALUE rb_mHnswlib) {
611
- rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
662
+ static VALUE define_class(VALUE outer) {
663
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(outer, "BruteforceSearch", rb_cObject);
612
664
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
613
- rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
665
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_initialize), -1);
666
+ rb_define_method(rb_cHnswlibBruteforceSearch, "init_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init_index), -1);
614
667
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
615
668
  rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
616
669
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
@@ -625,34 +678,66 @@ public:
625
678
  private:
626
679
  static const rb_data_type_t hnsw_bruteforcesearch_type;
627
680
 
628
- static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
681
+ static VALUE _hnsw_bruteforcesearch_initialize(int argc, VALUE* argv, VALUE self) {
629
682
  VALUE kw_args = Qnil;
630
- ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
683
+ ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
631
684
  VALUE kw_values[2] = {Qundef, Qundef};
632
685
  rb_scan_args(argc, argv, ":", &kw_args);
633
686
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
634
687
 
635
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
636
- rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
637
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
688
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
689
+ rb_raise(rb_eTypeError, "expected space, String");
690
+ return Qnil;
691
+ }
692
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
693
+ strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
694
+ rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
638
695
  return Qnil;
639
696
  }
640
697
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
641
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
698
+ rb_raise(rb_eTypeError, "expected dim, Integer");
642
699
  return Qnil;
643
700
  }
644
701
 
645
- rb_iv_set(self, "@space", kw_values[0]);
646
702
  hnswlib::SpaceInterface<float>* space;
647
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
648
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
703
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
704
+ rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
705
+ } else {
706
+ rb_iv_set(self, "@space",
707
+ rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
708
+ }
709
+
710
+ rb_iv_set(self, "@normalize", Qfalse);
711
+ if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
712
+
713
+ return Qnil;
714
+ };
715
+
716
+ static VALUE _hnsw_bruteforcesearch_init_index(int argc, VALUE* argv, VALUE self) {
717
+ VALUE kw_args = Qnil;
718
+ ID kw_table[1] = {rb_intern("max_elements")};
719
+ VALUE kw_values[1] = {Qundef};
720
+ rb_scan_args(argc, argv, ":", &kw_args);
721
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
722
+
723
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
724
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
725
+ return Qnil;
726
+ }
727
+
728
+ hnswlib::SpaceInterface<float>* space = nullptr;
729
+ VALUE ivspace = rb_iv_get(self, "@space");
730
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
731
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
649
732
  } else {
650
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
733
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
651
734
  }
652
- const size_t max_elements = NUM2SIZET(kw_values[1]);
735
+
736
+ const size_t max_elements = NUM2SIZET(kw_values[0]);
653
737
 
654
738
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
655
739
  try {
740
+ ptr->~BruteforceSearch();
656
741
  new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
657
742
  } catch (const std::runtime_error& e) {
658
743
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -681,6 +766,15 @@ private:
681
766
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
682
767
  for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
683
768
 
769
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
770
+ float norm = 0.0;
771
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
772
+ norm = std::sqrt(std::fabs(norm));
773
+ if (norm >= 0.0) {
774
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
775
+ }
776
+ }
777
+
684
778
  try {
685
779
  get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
686
780
  } catch (const std::runtime_error& e) {
@@ -729,8 +823,15 @@ private:
729
823
  }
730
824
 
731
825
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
732
- for (size_t i = 0; i < dim; i++) {
733
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
826
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
827
+
828
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
829
+ float norm = 0.0;
830
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
831
+ norm = std::sqrt(std::fabs(norm));
832
+ if (norm >= 0.0) {
833
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
834
+ }
734
835
  }
735
836
 
736
837
  std::priority_queue<std::pair<float, size_t>> result =
@@ -3,7 +3,7 @@
3
3
  # Hnswlib.rb provides Ruby bindings for the Hnswlib.
4
4
  module Hnswlib
5
5
  # The version of Hnswlib.rb you install.
6
- VERSION = '0.7.0'
6
+ VERSION = '0.8.1'
7
7
 
8
8
  # The version of Hnswlib included with gem.
9
9
  HSWLIB_VERSION = '0.7.0'
data/lib/hnswlib.rb CHANGED
@@ -18,6 +18,7 @@ module Hnswlib
18
18
  #
19
19
  # index.get_nns_by_item(0, 100)
20
20
  #
21
+ # @deprecated This class was prepared as a class with an interface similar to Annoy, but it is not very useful and will be deleted in the next version.
21
22
  class HnswIndex
22
23
  # Returns the metric of index.
23
24
  # @return [String]
@@ -27,7 +28,7 @@ module Hnswlib
27
28
  #
28
29
  # @param n_features [Integer] The number of features (dimensions) of stored vector.
29
30
  # @param max_item [Integer] The maximum number of items.
30
- # @param metric [String] The distance metric between vectors ('l2' or 'dot').
31
+ # @param metric [String] The distance metric between vectors ('l2', 'dot', or 'cosine').
31
32
  # @param m [Integer] The maximum number of outgoing connections in the graph
32
33
  # @param ef_construction [Integer] The size of the dynamic list for the nearest neighbors. It controls the index time/accuracy trade-off.
33
34
  # @param random_seed [Integer] The seed value used to initialize the random generator.
@@ -35,15 +36,10 @@ module Hnswlib
35
36
  def initialize(n_features:, max_item:, metric: 'l2', m: 16, ef_construction: 200,
36
37
  random_seed: 100, allow_replace_removed: false)
37
38
  @metric = metric
38
- space = if @metric == 'dot'
39
- Hnswlib::InnerProductSpace.new(n_features)
40
- else
41
- Hnswlib::L2Space.new(n_features)
42
- end
43
- @index = Hnswlib::HierarchicalNSW.new(
44
- space: space, max_elements: max_item, m: m, ef_construction: ef_construction,
45
- random_seed: random_seed, allow_replace_deleted: allow_replace_removed
46
- )
39
+ space = @metric == 'dot' ? 'ip' : 'l2'
40
+ @index = Hnswlib::HierarchicalNSW.new(space: space, dim: n_features)
41
+ @index.init_index(max_elements: max_item, m: m, ef_construction: ef_construction,
42
+ random_seed: random_seed, allow_replace_deleted: allow_replace_removed)
47
43
  end
48
44
 
49
45
  # Add item to be indexed.
data/sig/hnswlib.rbs CHANGED
@@ -40,30 +40,36 @@ module Hnswlib
40
40
  class BruteforceSearch
41
41
  attr_accessor space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace)
42
42
 
43
- def initialize: (space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace) space, max_elements: Integer max_elements) -> void
43
+ def initialize: (space: String space, dim: Integer dim) -> void
44
+ def init_index: (max_elements: Integer max_elements) -> void
44
45
  def add_point: (Array[Float] arr, Integer idx) -> bool
45
46
  def current_count: () -> Integer
46
47
  def load_index: (String filename) -> void
47
48
  def max_elements: () -> Integer
48
49
  def remove_point: (Integer idx) -> void
49
50
  def save_index: (String filename) -> void
50
- def search_knn: (Array[Float] arr, Integer k) -> [Array[Integer], Array[Float]]
51
+ def search_knn: (Array[Float] arr, Integer k, ?filter: Proc filter) -> [Array[Integer], Array[Float]]
51
52
  end
52
53
 
53
54
  class HierarchicalNSW
54
55
  attr_accessor space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace)
55
56
 
56
- def initialize: (space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace) space, max_elements: Integer max_elements, ?m: Integer m, ?ef_construction: Integer ef_construction, ?random_seed: Integer random_seed, ?allow_replace_deleted: (true | false) allow_replace_deleted) -> void
57
+ def initialize: (space: String space, dim: Integer dim) -> void
58
+ def init_index: (max_elements: Integer max_elements, ?m: Integer m, ?ef_construction: Integer ef_construction, ?random_seed: Integer random_seed, ?allow_replace_deleted: (true | false) allow_replace_deleted) -> void
57
59
  def add_point: (Array[Float] arr, Integer idx, ?replace_deleted: (true | false) replace_deleted) -> bool
58
60
  def current_count: () -> Integer
59
61
  def get_ids: () -> Array[Integer]
60
62
  def get_point: (Integer idx) -> Array[Float]
61
63
  def load_index: (String filename, ?allow_replace_deleted: (true | false) allow_replace_deleted) -> void
62
64
  def mark_deleted: (Integer idx) -> void
65
+ def unmark_deleted: (Integer idx) -> void
63
66
  def max_elements: () -> Integer
64
67
  def resize_index: (Integer new_max_elements) -> void
65
68
  def save_index: (String filename) -> void
66
- def search_knn: (Array[Float] arr, Integer k) -> [Array[Integer], Array[Float]]
69
+ def search_knn: (Array[Float] arr, Integer k, ?filter: Proc filter) -> [Array[Integer], Array[Float]]
67
70
  def set_ef: (Integer ef) -> void
71
+ def get_ef: () -> Integer
72
+ def ef_construction: () -> Integer
73
+ def m: () -> Integer
68
74
  end
69
75
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: hnswlib
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.7.0
4
+ version: 0.8.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-04 00:00:00.000000000 Z
11
+ date: 2023-03-18 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
14
14
  email: