hnswlib 0.7.0 → 0.8.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
4
- data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
3
+ metadata.gz: 33dd2f9cd8656dfe469b4b9baf319ae4e0b78826dec58da8ae78516077a6fbc2
4
+ data.tar.gz: ad32e54a325136587ae12716243b88ed1269c7bc9cd4e92e9aa8896b2517d4d2
5
5
  SHA512:
6
- metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
7
- data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
6
+ metadata.gz: 361822ebf216d8f8ba3abf8314e52b1a72887fb2a03aea29572940376326b128442fc007fa7de6a6e925facb0b9295076fb10a98ff053aad7e0fbe41e4973c2a
7
+ data.tar.gz: 38f9f4b10665d9e87940b23fae51a15640229367f6bb74a7dba176bab6e798c8fad8abce106c07e38e0acc9046949f80208f9f9b30e3341aceb5a6193dac3aca
data/CHANGELOG.md CHANGED
@@ -1,3 +1,29 @@
1
+ ## [0.8.1] - 2023-03-18
2
+
3
+ - Update the type declarations of HierarchicalNSW and BruteforceSearch to match the recent changes.
4
+
5
+ ## [0.8.0] - 2023-03-14
6
+
7
+ **Breaking change:**
8
+
9
+ - The space argument must now be given as a String to the `initialize` method
10
+ in [HierarchicalNSW](https://yoshoku.github.io/hnswlib.rb/doc/Hnswlib/HierarchicalNSW.html) and [BruteforceSearch](https://yoshoku.github.io/hnswlib.rb/doc/Hnswlib/BruteforceSearch.html).
11
+ - Add `init_index` method to HierarchicalNSW and BruteforceSearch.
12
+ Along with this, some arguments of the `initialize` method have been moved to the `init_index` method.
13
+ ```ruby
14
+ require 'hnswlib'
15
+
16
+ n_features = 3
17
+ max_elements = 10
18
+
19
+ hnsw = Hnswlib::HierarchicalNSW.new(space: 'l2', dim: n_features)
20
+ hnsw.init_index(max_elements: max_elements, m: 16, ef_construction: 200, random_seed: 42, allow_replace_deleted: false)
21
+
22
+ bf = Hnswlib::BruteforceSearch.new(space: 'l2', dim: n_features)
23
+ bf.init_index(max_elements: max_elements)
24
+ ```
25
+ - Deprecate [HnswIndex](https://yoshoku.github.io/hnswlib.rb/doc/Hnswlib/HnswIndex.html), which has an interface similar to Annoy.
26
+
1
27
  ## [0.7.0] - 2023-03-04
2
28
 
3
29
  - Update bundled hnswlib version to 0.7.0.
data/README.md CHANGED
@@ -48,19 +48,26 @@ $ gem install hnswlib -- --with-cxxflags=-march=native
48
48
  ```ruby
49
49
  require 'hnswlib'
50
50
 
51
- f = 40 # length of item vector that will be indexed.
52
- t = Hnswlib::HnswIndex.new(n_features: f, max_item: 1000)
51
+ f = 40 # length of each data point vector that will be indexed.
52
+ t = Hnswlib::HierarchicalNSW.new(space: 'l2', dim: f)
53
+ t.init_index(max_elements: 1000)
53
54
 
54
55
  1000.times do |i|
55
56
  v = Array.new(f) { rand }
56
- t.add_item(i, v)
57
+ t.add_point(v, i)
57
58
  end
58
59
 
59
- t.save('test.ann')
60
+ t.save_index('test.ann')
61
+ ```
62
+
63
+ ```ruby
64
+ require 'hnswlib'
65
+
66
+ u = Hnswlib::HierarchicalNSW.new(space: 'l2', dim: f)
67
+ u.load_index('test.ann')
60
68
 
61
- u = Hnswlib::HnswIndex.new(n_features: f, max_item: 1000)
62
- u.load('test.ann')
63
- p u.get_nns_by_item(0, 100) # will find the 100 nearest neighbors.
69
+ q = Array.new(f) { rand }
70
+ p u.search_knn(q, 100) # will find the 100 nearest neighbors.
64
71
  ```
65
72
 
66
73
  ## License
@@ -18,8 +18,6 @@
18
18
 
19
19
  #include "hnswlibext.hpp"
20
20
 
21
- VALUE rb_mHnswlib;
22
-
23
21
  extern "C" void Init_hnswlibext(void) {
24
22
  rb_mHnswlib = rb_define_module("Hnswlib");
25
23
  RbHnswlibL2Space::define_class(rb_mHnswlib);
@@ -23,9 +23,11 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <cmath>
26
27
  #include <new>
27
28
  #include <vector>
28
29
 
30
+ VALUE rb_mHnswlib;
29
31
  VALUE rb_cHnswlibL2Space;
30
32
  VALUE rb_cHnswlibInnerProductSpace;
31
33
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -52,8 +54,8 @@ public:
52
54
  return ptr;
53
55
  };
54
56
 
55
- static VALUE define_class(VALUE rb_mHnswlib) {
56
- rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
57
+ static VALUE define_class(VALUE outer) {
58
+ rb_cHnswlibL2Space = rb_define_class_under(outer, "L2Space", rb_cObject);
57
59
  rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
58
60
  rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
59
61
  rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
@@ -128,8 +130,8 @@ public:
128
130
  return ptr;
129
131
  };
130
132
 
131
- static VALUE define_class(VALUE rb_mHnswlib) {
132
- rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
133
+ static VALUE define_class(VALUE outer) {
134
+ rb_cHnswlibInnerProductSpace = rb_define_class_under(outer, "InnerProductSpace", rb_cObject);
133
135
  rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
134
136
  rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
135
137
  rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
@@ -218,10 +220,11 @@ public:
218
220
  return ptr;
219
221
  };
220
222
 
221
- static VALUE define_class(VALUE rb_mHnswlib) {
222
- rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
223
+ static VALUE define_class(VALUE outer) {
224
+ rb_cHnswlibHierarchicalNSW = rb_define_class_under(outer, "HierarchicalNSW", rb_cObject);
223
225
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
224
- rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_initialize), -1);
227
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "init_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init_index), -1);
225
228
  rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
226
229
  rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
227
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
@@ -244,59 +247,90 @@ public:
244
247
  private:
245
248
  static const rb_data_type_t hnsw_hierarchicalnsw_type;
246
249
 
247
- static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
250
+ static VALUE _hnsw_hierarchicalnsw_initialize(int argc, VALUE* argv, VALUE self) {
248
251
  VALUE kw_args = Qnil;
249
- ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
250
- rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
251
- VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
252
+ ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
253
+ VALUE kw_values[2] = {Qundef, Qundef};
252
254
  rb_scan_args(argc, argv, ":", &kw_args);
253
- rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
254
- if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
255
- if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
256
- if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
257
- if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
255
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
258
256
 
259
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
260
- rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
261
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
257
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
258
+ rb_raise(rb_eTypeError, "expected space, String");
259
+ return Qnil;
260
+ }
261
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
262
+ strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
263
+ rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
262
264
  return Qnil;
263
265
  }
264
266
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
267
+ rb_raise(rb_eTypeError, "expected dim, Integer");
268
+ return Qnil;
269
+ }
270
+
271
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
272
+ rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
273
+ } else {
274
+ rb_iv_set(self, "@space",
275
+ rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
276
+ }
277
+
278
+ rb_iv_set(self, "@normalize", Qfalse);
279
+ if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
280
+
281
+ return Qnil;
282
+ };
283
+
284
+ static VALUE _hnsw_hierarchicalnsw_init_index(int argc, VALUE* argv, VALUE self) {
285
+ VALUE kw_args = Qnil;
286
+ ID kw_table[5] = {rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"), rb_intern("random_seed"),
287
+ rb_intern("allow_replace_deleted")};
288
+ VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
289
+ rb_scan_args(argc, argv, ":", &kw_args);
290
+ rb_get_kwargs(kw_args, kw_table, 1, 4, kw_values);
291
+ if (kw_values[1] == Qundef) kw_values[1] = SIZET2NUM(16);
292
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(200);
293
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(100);
294
+ if (kw_values[4] == Qundef) kw_values[4] = Qfalse;
295
+
296
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
265
297
  rb_raise(rb_eTypeError, "expected max_elements, Integer");
266
298
  return Qnil;
267
299
  }
268
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
300
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
269
301
  rb_raise(rb_eTypeError, "expected m, Integer");
270
302
  return Qnil;
271
303
  }
272
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
304
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
273
305
  rb_raise(rb_eTypeError, "expected ef_construction, Integer");
274
306
  return Qnil;
275
307
  }
276
- if (!RB_INTEGER_TYPE_P(kw_values[4])) {
308
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
277
309
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
278
310
  return Qnil;
279
311
  }
280
- if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
312
+ if (!RB_TYPE_P(kw_values[4], T_TRUE) && !RB_TYPE_P(kw_values[4], T_FALSE)) {
281
313
  rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
282
314
  return Qnil;
283
315
  }
284
316
 
285
- rb_iv_set(self, "@space", kw_values[0]);
286
- hnswlib::SpaceInterface<float>* space;
287
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
288
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
317
+ hnswlib::SpaceInterface<float>* space = nullptr;
318
+ VALUE ivspace = rb_iv_get(self, "@space");
319
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
320
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
289
321
  } else {
290
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
322
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
291
323
  }
292
- const size_t max_elements = NUM2SIZET(kw_values[1]);
293
- const size_t m = NUM2SIZET(kw_values[2]);
294
- const size_t ef_construction = NUM2SIZET(kw_values[3]);
295
- const size_t random_seed = NUM2SIZET(kw_values[4]);
296
- const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
324
+
325
+ const size_t max_elements = NUM2SIZET(kw_values[0]);
326
+ const size_t m = NUM2SIZET(kw_values[1]);
327
+ const size_t ef_construction = NUM2SIZET(kw_values[2]);
328
+ const size_t random_seed = NUM2SIZET(kw_values[3]);
329
+ const bool allow_replace_deleted = kw_values[4] == Qtrue ? true : false;
297
330
 
298
331
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
299
332
  try {
333
+ ptr->~HierarchicalNSW();
300
334
  new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
301
335
  } catch (const std::runtime_error& e) {
302
336
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -335,20 +369,29 @@ private:
335
369
  return Qfalse;
336
370
  }
337
371
 
338
- float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
339
- for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
372
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
373
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
340
374
  const size_t idx = NUM2SIZET(_idx);
341
375
  const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
342
376
 
377
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
378
+ float norm = 0.0;
379
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
380
+ norm = std::sqrt(std::fabs(norm));
381
+ if (norm >= 0.0) {
382
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
383
+ }
384
+ }
385
+
343
386
  try {
344
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
387
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, idx, replace_deleted);
345
388
  } catch (const std::runtime_error& e) {
346
- ruby_xfree(arr);
389
+ ruby_xfree(vec);
347
390
  rb_raise(rb_eRuntimeError, "%s", e.what());
348
391
  return Qfalse;
349
392
  }
350
393
 
351
- ruby_xfree(arr);
394
+ ruby_xfree(vec);
352
395
  return Qtrue;
353
396
  };
354
397
 
@@ -392,6 +435,15 @@ private:
392
435
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
393
436
  }
394
437
 
438
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
439
+ float norm = 0.0;
440
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
441
+ norm = std::sqrt(std::fabs(norm));
442
+ if (norm >= 0.0) {
443
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
444
+ }
445
+ }
446
+
395
447
  std::priority_queue<std::pair<float, size_t>> result;
396
448
  try {
397
449
  result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
@@ -607,10 +659,11 @@ public:
607
659
  return ptr;
608
660
  };
609
661
 
610
- static VALUE define_class(VALUE rb_mHnswlib) {
611
- rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
662
+ static VALUE define_class(VALUE outer) {
663
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(outer, "BruteforceSearch", rb_cObject);
612
664
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
613
- rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
665
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_initialize), -1);
666
+ rb_define_method(rb_cHnswlibBruteforceSearch, "init_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init_index), -1);
614
667
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
615
668
  rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
616
669
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
@@ -625,34 +678,66 @@ public:
625
678
  private:
626
679
  static const rb_data_type_t hnsw_bruteforcesearch_type;
627
680
 
628
- static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
681
+ static VALUE _hnsw_bruteforcesearch_initialize(int argc, VALUE* argv, VALUE self) {
629
682
  VALUE kw_args = Qnil;
630
- ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
683
+ ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
631
684
  VALUE kw_values[2] = {Qundef, Qundef};
632
685
  rb_scan_args(argc, argv, ":", &kw_args);
633
686
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
634
687
 
635
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
636
- rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
637
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
688
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
689
+ rb_raise(rb_eTypeError, "expected space, String");
690
+ return Qnil;
691
+ }
692
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
693
+ strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
694
+ rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
638
695
  return Qnil;
639
696
  }
640
697
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
641
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
698
+ rb_raise(rb_eTypeError, "expected dim, Integer");
642
699
  return Qnil;
643
700
  }
644
701
 
645
- rb_iv_set(self, "@space", kw_values[0]);
646
702
  hnswlib::SpaceInterface<float>* space;
647
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
648
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
703
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
704
+ rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
705
+ } else {
706
+ rb_iv_set(self, "@space",
707
+ rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
708
+ }
709
+
710
+ rb_iv_set(self, "@normalize", Qfalse);
711
+ if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
712
+
713
+ return Qnil;
714
+ };
715
+
716
+ static VALUE _hnsw_bruteforcesearch_init_index(int argc, VALUE* argv, VALUE self) {
717
+ VALUE kw_args = Qnil;
718
+ ID kw_table[1] = {rb_intern("max_elements")};
719
+ VALUE kw_values[1] = {Qundef};
720
+ rb_scan_args(argc, argv, ":", &kw_args);
721
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
722
+
723
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
724
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
725
+ return Qnil;
726
+ }
727
+
728
+ hnswlib::SpaceInterface<float>* space = nullptr;
729
+ VALUE ivspace = rb_iv_get(self, "@space");
730
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
731
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
649
732
  } else {
650
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
733
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
651
734
  }
652
- const size_t max_elements = NUM2SIZET(kw_values[1]);
735
+
736
+ const size_t max_elements = NUM2SIZET(kw_values[0]);
653
737
 
654
738
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
655
739
  try {
740
+ ptr->~BruteforceSearch();
656
741
  new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
657
742
  } catch (const std::runtime_error& e) {
658
743
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -681,6 +766,15 @@ private:
681
766
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
682
767
  for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
683
768
 
769
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
770
+ float norm = 0.0;
771
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
772
+ norm = std::sqrt(std::fabs(norm));
773
+ if (norm >= 0.0) {
774
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
775
+ }
776
+ }
777
+
684
778
  try {
685
779
  get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
686
780
  } catch (const std::runtime_error& e) {
@@ -729,8 +823,15 @@ private:
729
823
  }
730
824
 
731
825
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
732
- for (size_t i = 0; i < dim; i++) {
733
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
826
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
827
+
828
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
829
+ float norm = 0.0;
830
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
831
+ norm = std::sqrt(std::fabs(norm));
832
+ if (norm >= 0.0) {
833
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
834
+ }
734
835
  }
735
836
 
736
837
  std::priority_queue<std::pair<float, size_t>> result =
@@ -3,7 +3,7 @@
3
3
  # Hnswlib.rb provides Ruby bindings for the Hnswlib.
4
4
  module Hnswlib
5
5
  # The version of Hnswlib.rb you install.
6
- VERSION = '0.7.0'
6
+ VERSION = '0.8.1'
7
7
 
8
8
  # The version of Hnswlib included with gem.
9
9
  HSWLIB_VERSION = '0.7.0'
data/lib/hnswlib.rb CHANGED
@@ -18,6 +18,7 @@ module Hnswlib
18
18
  #
19
19
  # index.get_nns_by_item(0, 100)
20
20
  #
21
+ # @deprecated This class was prepared as a class with an interface similar to Annoy, but it is not very useful and will be deleted in the next version.
21
22
  class HnswIndex
22
23
  # Returns the metric of index.
23
24
  # @return [String]
@@ -27,7 +28,7 @@ module Hnswlib
27
28
  #
28
29
  # @param n_features [Integer] The number of features (dimensions) of stored vector.
29
30
  # @param max_item [Integer] The maximum number of items.
30
- # @param metric [String] The distance metric between vectors ('l2' or 'dot').
31
+ # @param metric [String] The distance metric between vectors ('l2', 'dot', or 'cosine').
31
32
  # @param m [Integer] The maximum number of outgoing connections in the graph
32
33
  # @param ef_construction [Integer] The size of the dynamic list for the nearest neighbors. It controls the index time/accuracy trade-off.
33
34
  # @param random_seed [Integer] The seed value used to initialize the random generator.
@@ -35,15 +36,10 @@ module Hnswlib
35
36
  def initialize(n_features:, max_item:, metric: 'l2', m: 16, ef_construction: 200,
36
37
  random_seed: 100, allow_replace_removed: false)
37
38
  @metric = metric
38
- space = if @metric == 'dot'
39
- Hnswlib::InnerProductSpace.new(n_features)
40
- else
41
- Hnswlib::L2Space.new(n_features)
42
- end
43
- @index = Hnswlib::HierarchicalNSW.new(
44
- space: space, max_elements: max_item, m: m, ef_construction: ef_construction,
45
- random_seed: random_seed, allow_replace_deleted: allow_replace_removed
46
- )
39
+ space = @metric == 'dot' ? 'ip' : 'l2'
40
+ @index = Hnswlib::HierarchicalNSW.new(space: space, dim: n_features)
41
+ @index.init_index(max_elements: max_item, m: m, ef_construction: ef_construction,
42
+ random_seed: random_seed, allow_replace_deleted: allow_replace_removed)
47
43
  end
48
44
 
49
45
  # Add item to be indexed.
data/sig/hnswlib.rbs CHANGED
@@ -40,30 +40,36 @@ module Hnswlib
40
40
  class BruteforceSearch
41
41
  attr_accessor space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace)
42
42
 
43
- def initialize: (space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace) space, max_elements: Integer max_elements) -> void
43
+ def initialize: (space: String space, dim: Integer dim) -> void
44
+ def init_index: (max_elements: Integer max_elements) -> void
44
45
  def add_point: (Array[Float] arr, Integer idx) -> bool
45
46
  def current_count: () -> Integer
46
47
  def load_index: (String filename) -> void
47
48
  def max_elements: () -> Integer
48
49
  def remove_point: (Integer idx) -> void
49
50
  def save_index: (String filename) -> void
50
- def search_knn: (Array[Float] arr, Integer k) -> [Array[Integer], Array[Float]]
51
+ def search_knn: (Array[Float] arr, Integer k, ?filter: Proc filter) -> [Array[Integer], Array[Float]]
51
52
  end
52
53
 
53
54
  class HierarchicalNSW
54
55
  attr_accessor space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace)
55
56
 
56
- def initialize: (space: (::Hnswlib::L2Space | ::Hnswlib::InnerProductSpace) space, max_elements: Integer max_elements, ?m: Integer m, ?ef_construction: Integer ef_construction, ?random_seed: Integer random_seed, ?allow_replace_deleted: (true | false) allow_replace_deleted) -> void
57
+ def initialize: (space: String space, dim: Integer dim) -> void
58
+ def init_index: (max_elements: Integer max_elements, ?m: Integer m, ?ef_construction: Integer ef_construction, ?random_seed: Integer random_seed, ?allow_replace_deleted: (true | false) allow_replace_deleted) -> void
57
59
  def add_point: (Array[Float] arr, Integer idx, ?replace_deleted: (true | false) replace_deleted) -> bool
58
60
  def current_count: () -> Integer
59
61
  def get_ids: () -> Array[Integer]
60
62
  def get_point: (Integer idx) -> Array[Float]
61
63
  def load_index: (String filename, ?allow_replace_deleted: (true | false) allow_replace_deleted) -> void
62
64
  def mark_deleted: (Integer idx) -> void
65
+ def unmark_deleted: (Integer idx) -> void
63
66
  def max_elements: () -> Integer
64
67
  def resize_index: (Integer new_max_elements) -> void
65
68
  def save_index: (String filename) -> void
66
- def search_knn: (Array[Float] arr, Integer k) -> [Array[Integer], Array[Float]]
69
+ def search_knn: (Array[Float] arr, Integer k, ?filter: Proc filter) -> [Array[Integer], Array[Float]]
67
70
  def set_ef: (Integer ef) -> void
71
+ def get_ef: () -> Integer
72
+ def ef_construction: () -> Integer
73
+ def m: () -> Integer
68
74
  end
69
75
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: hnswlib
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.7.0
4
+ version: 0.8.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-04 00:00:00.000000000 Z
11
+ date: 2023-03-18 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
14
14
  email: