hnswlib 0.8.1 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33dd2f9cd8656dfe469b4b9baf319ae4e0b78826dec58da8ae78516077a6fbc2
4
- data.tar.gz: ad32e54a325136587ae12716243b88ed1269c7bc9cd4e92e9aa8896b2517d4d2
3
+ metadata.gz: 39458b40736b4a330a0c1769a3adcf37d9d40614b1270d33885932244e943b74
4
+ data.tar.gz: d49bf8e158c55235fdeca0656cd3ddab6a07e5abc1468eb8917b21f9baf5d89c
5
5
  SHA512:
6
- metadata.gz: 361822ebf216d8f8ba3abf8314e52b1a72887fb2a03aea29572940376326b128442fc007fa7de6a6e925facb0b9295076fb10a98ff053aad7e0fbe41e4973c2a
7
- data.tar.gz: 38f9f4b10665d9e87940b23fae51a15640229367f6bb74a7dba176bab6e798c8fad8abce106c07e38e0acc9046949f80208f9f9b30e3341aceb5a6193dac3aca
6
+ metadata.gz: ec8560f2b7aa9bb5df098b9c863e4602465a42e8ea27a19be55b9d62e9ba0d32fa70ccbcb13323b771370df5ded9370f9f91d17c8d63aecc18d2b45e29df0149
7
+ data.tar.gz: c9b434fb313f41e2f0570070944eb6eb3b4432a65fce2c63ca6592ab1de30f9b04834d87e1bd40db465dceea0eaac2279378eda537f6b7d51c2bdb0aa06e1ab6
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## [[0.9.0](https://github.com/yoshoku/hnswlib.rb/compare/v0.8.1...v0.9.0)] - 2023-12-16
2
+
3
+ - Update bundled hnswlib version to 0.8.0.
4
+ - Multi-vector document search and epsilon search, which were added only to the C++ version, are not supported. These features will be supported in a future release.
5
+
1
6
  ## [0.8.1] - 2023-03-18
2
7
 
3
8
  - Update the type declarations of HierarchicalNSW and BruteforceSearch along with recent changes.
@@ -528,10 +528,6 @@ private:
528
528
  free(index->linkLists_);
529
529
  index->linkLists_ = nullptr;
530
530
  }
531
- if (index->visited_list_pool_) {
532
- delete index->visited_list_pool_;
533
- index->visited_list_pool_ = nullptr;
534
- }
535
531
 
536
532
  try {
537
533
  index->loadIndex(filename, space);
@@ -85,10 +85,16 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
85
85
 
86
86
 
87
87
  void removePoint(labeltype cur_external) {
88
- size_t cur_c = dict_external_to_internal[cur_external];
88
+ std::unique_lock<std::mutex> lock(index_lock);
89
89
 
90
- dict_external_to_internal.erase(cur_external);
90
+ auto found = dict_external_to_internal.find(cur_external);
91
+ if (found == dict_external_to_internal.end()) {
92
+ return;
93
+ }
94
+
95
+ dict_external_to_internal.erase(found);
91
96
 
97
+ size_t cur_c = found->second;
92
98
  labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
93
99
  dict_external_to_internal[label] = cur_c;
94
100
  memcpy(data_ + size_per_element_ * cur_c,
@@ -107,7 +113,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
107
113
  dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
108
114
  labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
109
115
  if ((!isIdAllowed) || (*isIdAllowed)(label)) {
110
- topResults.push(std::pair<dist_t, labeltype>(dist, label));
116
+ topResults.emplace(dist, label);
111
117
  }
112
118
  }
113
119
  dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
@@ -116,7 +122,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
116
122
  if (dist <= lastdist) {
117
123
  labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
118
124
  if ((!isIdAllowed) || (*isIdAllowed)(label)) {
119
- topResults.push(std::pair<dist_t, labeltype>(dist, label));
125
+ topResults.emplace(dist, label);
120
126
  }
121
127
  if (topResults.size() > k)
122
128
  topResults.pop();
@@ -8,6 +8,7 @@
8
8
  #include <assert.h>
9
9
  #include <unordered_set>
10
10
  #include <list>
11
+ #include <memory>
11
12
 
12
13
  namespace hnswlib {
13
14
  typedef unsigned int tableint;
@@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
33
34
  double mult_{0.0}, revSize_{0.0};
34
35
  int maxlevel_{0};
35
36
 
36
- VisitedListPool *visited_list_pool_{nullptr};
37
+ std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};
37
38
 
38
39
  // Locks operations with element by label value
39
40
  mutable std::vector<std::mutex> label_op_locks_;
@@ -93,8 +94,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
93
94
  size_t ef_construction = 200,
94
95
  size_t random_seed = 100,
95
96
  bool allow_replace_deleted = false)
96
- : link_list_locks_(max_elements),
97
- label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
97
+ : label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
98
+ link_list_locks_(max_elements),
98
99
  element_levels_(max_elements),
99
100
  allow_replace_deleted_(allow_replace_deleted) {
100
101
  max_elements_ = max_elements;
@@ -102,7 +103,13 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
102
103
  data_size_ = s->get_data_size();
103
104
  fstdistfunc_ = s->get_dist_func();
104
105
  dist_func_param_ = s->get_dist_func_param();
105
- M_ = M;
106
+ if ( M <= 10000 ) {
107
+ M_ = M;
108
+ } else {
109
+ HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
110
+ HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
111
+ M_ = 10000;
112
+ }
106
113
  maxM_ = M_;
107
114
  maxM0_ = M_ * 2;
108
115
  ef_construction_ = std::max(ef_construction, M_);
@@ -123,7 +130,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
123
130
 
124
131
  cur_element_count = 0;
125
132
 
126
- visited_list_pool_ = new VisitedListPool(1, max_elements);
133
+ visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));
127
134
 
128
135
  // initializations for special treatment of the first node
129
136
  enterpoint_node_ = -1;
@@ -139,13 +146,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
139
146
 
140
147
 
141
148
  ~HierarchicalNSW() {
149
+ clear();
150
+ }
151
+
152
+ void clear() {
142
153
  free(data_level0_memory_);
154
+ data_level0_memory_ = nullptr;
143
155
  for (tableint i = 0; i < cur_element_count; i++) {
144
156
  if (element_levels_[i] > 0)
145
157
  free(linkLists_[i]);
146
158
  }
147
159
  free(linkLists_);
148
- delete visited_list_pool_;
160
+ linkLists_ = nullptr;
161
+ cur_element_count = 0;
162
+ visited_list_pool_.reset(nullptr);
149
163
  }
150
164
 
151
165
 
@@ -292,9 +306,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
292
306
  }
293
307
 
294
308
 
295
- template <bool has_deletions, bool collect_metrics = false>
309
+ // bare_bone_search means there is no check for deletions and the stop condition is ignored, in return for extra performance
310
+ template <bool bare_bone_search = true, bool collect_metrics = false>
296
311
  std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
297
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
312
+ searchBaseLayerST(
313
+ tableint ep_id,
314
+ const void *data_point,
315
+ size_t ef,
316
+ BaseFilterFunctor* isIdAllowed = nullptr,
317
+ BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
298
318
  VisitedList *vl = visited_list_pool_->getFreeVisitedList();
299
319
  vl_type *visited_array = vl->mass;
300
320
  vl_type visited_array_tag = vl->curV;
@@ -303,10 +323,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
303
323
  std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
304
324
 
305
325
  dist_t lowerBound;
306
- if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
307
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
326
+ if (bare_bone_search ||
327
+ (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
328
+ char* ep_data = getDataByInternalId(ep_id);
329
+ dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
308
330
  lowerBound = dist;
309
331
  top_candidates.emplace(dist, ep_id);
332
+ if (!bare_bone_search && stop_condition) {
333
+ stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
334
+ }
310
335
  candidate_set.emplace(-dist, ep_id);
311
336
  } else {
312
337
  lowerBound = std::numeric_limits<dist_t>::max();
@@ -317,9 +342,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
317
342
 
318
343
  while (!candidate_set.empty()) {
319
344
  std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
345
+ dist_t candidate_dist = -current_node_pair.first;
320
346
 
321
- if ((-current_node_pair.first) > lowerBound &&
322
- (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
347
+ bool flag_stop_search;
348
+ if (bare_bone_search) {
349
+ flag_stop_search = candidate_dist > lowerBound;
350
+ } else {
351
+ if (stop_condition) {
352
+ flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
353
+ } else {
354
+ flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
355
+ }
356
+ }
357
+ if (flag_stop_search) {
323
358
  break;
324
359
  }
325
360
  candidate_set.pop();
@@ -354,7 +389,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
354
389
  char *currObj1 = (getDataByInternalId(candidate_id));
355
390
  dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
356
391
 
357
- if (top_candidates.size() < ef || lowerBound > dist) {
392
+ bool flag_consider_candidate;
393
+ if (!bare_bone_search && stop_condition) {
394
+ flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
395
+ } else {
396
+ flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
397
+ }
398
+
399
+ if (flag_consider_candidate) {
358
400
  candidate_set.emplace(-dist, candidate_id);
359
401
  #ifdef USE_SSE
360
402
  _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
@@ -362,11 +404,30 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
362
404
  _MM_HINT_T0); ////////////////////////
363
405
  #endif
364
406
 
365
- if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
407
+ if (bare_bone_search ||
408
+ (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
366
409
  top_candidates.emplace(dist, candidate_id);
410
+ if (!bare_bone_search && stop_condition) {
411
+ stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
412
+ }
413
+ }
367
414
 
368
- if (top_candidates.size() > ef)
415
+ bool flag_remove_extra = false;
416
+ if (!bare_bone_search && stop_condition) {
417
+ flag_remove_extra = stop_condition->should_remove_extra();
418
+ } else {
419
+ flag_remove_extra = top_candidates.size() > ef;
420
+ }
421
+ while (flag_remove_extra) {
422
+ tableint id = top_candidates.top().second;
369
423
  top_candidates.pop();
424
+ if (!bare_bone_search && stop_condition) {
425
+ stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
426
+ flag_remove_extra = stop_condition->should_remove_extra();
427
+ } else {
428
+ flag_remove_extra = top_candidates.size() > ef;
429
+ }
430
+ }
370
431
 
371
432
  if (!top_candidates.empty())
372
433
  lowerBound = top_candidates.top().first;
@@ -381,8 +442,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
381
442
 
382
443
 
383
444
  void getNeighborsByHeuristic2(
384
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
385
- const size_t M) {
445
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
446
+ const size_t M) {
386
447
  if (top_candidates.size() < M) {
387
448
  return;
388
449
  }
@@ -574,8 +635,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
574
635
  if (new_max_elements < cur_element_count)
575
636
  throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
576
637
 
577
- delete visited_list_pool_;
578
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
638
+ visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));
579
639
 
580
640
  element_levels_.resize(new_max_elements);
581
641
 
@@ -596,6 +656,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
596
656
  max_elements_ = new_max_elements;
597
657
  }
598
658
 
659
+ size_t indexFileSize() const {
660
+ size_t size = 0;
661
+ size += sizeof(offsetLevel0_);
662
+ size += sizeof(max_elements_);
663
+ size += sizeof(cur_element_count);
664
+ size += sizeof(size_data_per_element_);
665
+ size += sizeof(label_offset_);
666
+ size += sizeof(offsetData_);
667
+ size += sizeof(maxlevel_);
668
+ size += sizeof(enterpoint_node_);
669
+ size += sizeof(maxM_);
670
+
671
+ size += sizeof(maxM0_);
672
+ size += sizeof(M_);
673
+ size += sizeof(mult_);
674
+ size += sizeof(ef_construction_);
675
+
676
+ size += cur_element_count * size_data_per_element_;
677
+
678
+ for (size_t i = 0; i < cur_element_count; i++) {
679
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
680
+ size += sizeof(linkListSize);
681
+ size += linkListSize;
682
+ }
683
+ return size;
684
+ }
599
685
 
600
686
  void saveIndex(const std::string &location) {
601
687
  std::ofstream output(location, std::ios::binary);
@@ -634,6 +720,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
634
720
  if (!input.is_open())
635
721
  throw std::runtime_error("Cannot open file");
636
722
 
723
+ clear();
637
724
  // get file size:
638
725
  input.seekg(0, input.end);
639
726
  std::streampos total_filesize = input.tellg();
@@ -699,7 +786,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
699
786
  std::vector<std::mutex>(max_elements).swap(link_list_locks_);
700
787
  std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
701
788
 
702
- visited_list_pool_ = new VisitedListPool(1, max_elements);
789
+ visited_list_pool_.reset(new VisitedListPool(1, max_elements));
703
790
 
704
791
  linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
705
792
  if (linkLists_ == nullptr)
@@ -753,7 +840,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
753
840
  size_t dim = *((size_t *) dist_func_param_);
754
841
  std::vector<data_t> data;
755
842
  data_t* data_ptr = (data_t*) data_ptrv;
756
- for (int i = 0; i < dim; i++) {
843
+ for (size_t i = 0; i < dim; i++) {
757
844
  data.push_back(*data_ptr);
758
845
  data_ptr += 1;
759
846
  }
@@ -1217,11 +1304,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1217
1304
  }
1218
1305
 
1219
1306
  std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1220
- if (num_deleted_) {
1221
- top_candidates = searchBaseLayerST<true, true>(
1307
+ bool bare_bone_search = !num_deleted_ && !isIdAllowed;
1308
+ if (bare_bone_search) {
1309
+ top_candidates = searchBaseLayerST<true>(
1222
1310
  currObj, query_data, std::max(ef_, k), isIdAllowed);
1223
1311
  } else {
1224
- top_candidates = searchBaseLayerST<false, true>(
1312
+ top_candidates = searchBaseLayerST<false>(
1225
1313
  currObj, query_data, std::max(ef_, k), isIdAllowed);
1226
1314
  }
1227
1315
 
@@ -1237,6 +1325,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1237
1325
  }
1238
1326
 
1239
1327
 
1328
+ std::vector<std::pair<dist_t, labeltype >>
1329
+ searchStopConditionClosest(
1330
+ const void *query_data,
1331
+ BaseSearchStopCondition<dist_t>& stop_condition,
1332
+ BaseFilterFunctor* isIdAllowed = nullptr) const {
1333
+ std::vector<std::pair<dist_t, labeltype >> result;
1334
+ if (cur_element_count == 0) return result;
1335
+
1336
+ tableint currObj = enterpoint_node_;
1337
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1338
+
1339
+ for (int level = maxlevel_; level > 0; level--) {
1340
+ bool changed = true;
1341
+ while (changed) {
1342
+ changed = false;
1343
+ unsigned int *data;
1344
+
1345
+ data = (unsigned int *) get_linklist(currObj, level);
1346
+ int size = getListCount(data);
1347
+ metric_hops++;
1348
+ metric_distance_computations+=size;
1349
+
1350
+ tableint *datal = (tableint *) (data + 1);
1351
+ for (int i = 0; i < size; i++) {
1352
+ tableint cand = datal[i];
1353
+ if (cand < 0 || cand > max_elements_)
1354
+ throw std::runtime_error("cand error");
1355
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1356
+
1357
+ if (d < curdist) {
1358
+ curdist = d;
1359
+ currObj = cand;
1360
+ changed = true;
1361
+ }
1362
+ }
1363
+ }
1364
+ }
1365
+
1366
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1367
+ top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);
1368
+
1369
+ size_t sz = top_candidates.size();
1370
+ result.resize(sz);
1371
+ while (!top_candidates.empty()) {
1372
+ result[--sz] = top_candidates.top();
1373
+ top_candidates.pop();
1374
+ }
1375
+
1376
+ stop_condition.filter_results(result);
1377
+
1378
+ return result;
1379
+ }
1380
+
1381
+
1240
1382
  void checkIntegrity() {
1241
1383
  int connections_checked = 0;
1242
1384
  std::vector <int > inbound_connections_num(cur_element_count, 0);
@@ -1247,7 +1389,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1247
1389
  tableint *data = (tableint *) (ll_cur + 1);
1248
1390
  std::unordered_set<tableint> s;
1249
1391
  for (int j = 0; j < size; j++) {
1250
- assert(data[j] > 0);
1251
1392
  assert(data[j] < cur_element_count);
1252
1393
  assert(data[j] != i);
1253
1394
  inbound_connections_num[data[j]]++;
@@ -1,4 +1,13 @@
1
1
  #pragma once
2
+
3
+ // https://github.com/nmslib/hnswlib/pull/508
4
+ // This allows others to provide their own error stream (e.g. RcppHNSW)
5
+ #ifndef HNSWLIB_ERR_OVERRIDE
6
+ #define HNSWERR std::cerr
7
+ #else
8
+ #define HNSWERR HNSWLIB_ERR_OVERRIDE
9
+ #endif
10
+
2
11
  #ifndef NO_MANUAL_VECTORIZATION
3
12
  #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
4
13
  #define USE_SSE
@@ -15,7 +24,7 @@
15
24
  #ifdef _MSC_VER
16
25
  #include <intrin.h>
17
26
  #include <stdexcept>
18
- void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
27
+ static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
19
28
  __cpuidex(out, eax, ecx);
20
29
  }
21
30
  static __int64 xgetbv(unsigned int x) {
@@ -122,6 +131,24 @@ class BaseFilterFunctor {
122
131
  virtual ~BaseFilterFunctor() {};
123
132
  };
124
133
 
134
+ template<typename dist_t>
135
+ class BaseSearchStopCondition {
136
+ public:
137
+ virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;
138
+
139
+ virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;
140
+
141
+ virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;
142
+
143
+ virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;
144
+
145
+ virtual bool should_remove_extra() = 0;
146
+
147
+ virtual void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) = 0;
148
+
149
+ virtual ~BaseSearchStopCondition() {}
150
+ };
151
+
125
152
  template <typename T>
126
153
  class pairGreater {
127
154
  public:
@@ -196,5 +223,6 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
196
223
 
197
224
  #include "space_l2.h"
198
225
  #include "space_ip.h"
226
+ #include "stop_condition.h"
199
227
  #include "bruteforce.h"
200
228
  #include "hnswalg.h"
@@ -157,19 +157,44 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void
157
157
 
158
158
  __m512 sum512 = _mm512_set1_ps(0);
159
159
 
160
- while (pVect1 < pEnd1) {
161
- //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
160
+ size_t loop = qty16 / 4;
162
161
 
162
+ while (loop--) {
163
163
  __m512 v1 = _mm512_loadu_ps(pVect1);
164
- pVect1 += 16;
165
164
  __m512 v2 = _mm512_loadu_ps(pVect2);
165
+ pVect1 += 16;
166
+ pVect2 += 16;
167
+
168
+ __m512 v3 = _mm512_loadu_ps(pVect1);
169
+ __m512 v4 = _mm512_loadu_ps(pVect2);
170
+ pVect1 += 16;
171
+ pVect2 += 16;
172
+
173
+ __m512 v5 = _mm512_loadu_ps(pVect1);
174
+ __m512 v6 = _mm512_loadu_ps(pVect2);
175
+ pVect1 += 16;
166
176
  pVect2 += 16;
167
- sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
177
+
178
+ __m512 v7 = _mm512_loadu_ps(pVect1);
179
+ __m512 v8 = _mm512_loadu_ps(pVect2);
180
+ pVect1 += 16;
181
+ pVect2 += 16;
182
+
183
+ sum512 = _mm512_fmadd_ps(v1, v2, sum512);
184
+ sum512 = _mm512_fmadd_ps(v3, v4, sum512);
185
+ sum512 = _mm512_fmadd_ps(v5, v6, sum512);
186
+ sum512 = _mm512_fmadd_ps(v7, v8, sum512);
168
187
  }
169
188
 
170
- _mm512_store_ps(TmpRes, sum512);
171
- float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
189
+ while (pVect1 < pEnd1) {
190
+ __m512 v1 = _mm512_loadu_ps(pVect1);
191
+ __m512 v2 = _mm512_loadu_ps(pVect2);
192
+ pVect1 += 16;
193
+ pVect2 += 16;
194
+ sum512 = _mm512_fmadd_ps(v1, v2, sum512);
195
+ }
172
196
 
197
+ float sum = _mm512_reduce_add_ps(sum512);
173
198
  return sum;
174
199
  }
175
200
 
@@ -0,0 +1,276 @@
1
+ #pragma once
2
+ #include "space_l2.h"
3
+ #include "space_ip.h"
4
+ #include <assert.h>
5
+ #include <unordered_map>
6
+
7
+ namespace hnswlib {
8
+
9
+ template<typename DOCIDTYPE>
10
+ class BaseMultiVectorSpace : public SpaceInterface<float> {
11
+ public:
12
+ virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;
13
+
14
+ virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
15
+ };
16
+
17
+
18
+ template<typename DOCIDTYPE>
19
+ class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
20
+ DISTFUNC<float> fstdistfunc_;
21
+ size_t data_size_;
22
+ size_t vector_size_;
23
+ size_t dim_;
24
+
25
+ public:
26
+ MultiVectorL2Space(size_t dim) {
27
+ fstdistfunc_ = L2Sqr;
28
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
29
+ #if defined(USE_AVX512)
30
+ if (AVX512Capable())
31
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
32
+ else if (AVXCapable())
33
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
34
+ #elif defined(USE_AVX)
35
+ if (AVXCapable())
36
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
37
+ #endif
38
+
39
+ if (dim % 16 == 0)
40
+ fstdistfunc_ = L2SqrSIMD16Ext;
41
+ else if (dim % 4 == 0)
42
+ fstdistfunc_ = L2SqrSIMD4Ext;
43
+ else if (dim > 16)
44
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
45
+ else if (dim > 4)
46
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
47
+ #endif
48
+ dim_ = dim;
49
+ vector_size_ = dim * sizeof(float);
50
+ data_size_ = vector_size_ + sizeof(DOCIDTYPE);
51
+ }
52
+
53
+ size_t get_data_size() override {
54
+ return data_size_;
55
+ }
56
+
57
+ DISTFUNC<float> get_dist_func() override {
58
+ return fstdistfunc_;
59
+ }
60
+
61
+ void *get_dist_func_param() override {
62
+ return &dim_;
63
+ }
64
+
65
+ DOCIDTYPE get_doc_id(const void *datapoint) override {
66
+ return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
67
+ }
68
+
69
+ void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
70
+ *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
71
+ }
72
+
73
+ ~MultiVectorL2Space() {}
74
+ };
75
+
76
+
77
+ template<typename DOCIDTYPE>
78
+ class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
79
+ DISTFUNC<float> fstdistfunc_;
80
+ size_t data_size_;
81
+ size_t vector_size_;
82
+ size_t dim_;
83
+
84
+ public:
85
+ MultiVectorInnerProductSpace(size_t dim) {
86
+ fstdistfunc_ = InnerProductDistance;
87
+ #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
88
+ #if defined(USE_AVX512)
89
+ if (AVX512Capable()) {
90
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
91
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
92
+ } else if (AVXCapable()) {
93
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
94
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
95
+ }
96
+ #elif defined(USE_AVX)
97
+ if (AVXCapable()) {
98
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
99
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
100
+ }
101
+ #endif
102
+ #if defined(USE_AVX)
103
+ if (AVXCapable()) {
104
+ InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
105
+ InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
106
+ }
107
+ #endif
108
+
109
+ if (dim % 16 == 0)
110
+ fstdistfunc_ = InnerProductDistanceSIMD16Ext;
111
+ else if (dim % 4 == 0)
112
+ fstdistfunc_ = InnerProductDistanceSIMD4Ext;
113
+ else if (dim > 16)
114
+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
115
+ else if (dim > 4)
116
+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
117
+ #endif
118
+ vector_size_ = dim * sizeof(float);
119
+ data_size_ = vector_size_ + sizeof(DOCIDTYPE);
120
+ }
121
+
122
+ size_t get_data_size() override {
123
+ return data_size_;
124
+ }
125
+
126
+ DISTFUNC<float> get_dist_func() override {
127
+ return fstdistfunc_;
128
+ }
129
+
130
+ void *get_dist_func_param() override {
131
+ return &dim_;
132
+ }
133
+
134
+ DOCIDTYPE get_doc_id(const void *datapoint) override {
135
+ return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
136
+ }
137
+
138
+ void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
139
+ *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
140
+ }
141
+
142
+ ~MultiVectorInnerProductSpace() {}
143
+ };
144
+
145
+
146
+ template<typename DOCIDTYPE, typename dist_t>
147
+ class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
148
+ size_t curr_num_docs_;
149
+ size_t num_docs_to_search_;
150
+ size_t ef_collection_;
151
+ std::unordered_map<DOCIDTYPE, size_t> doc_counter_;
152
+ std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;
153
+ BaseMultiVectorSpace<DOCIDTYPE>& space_;
154
+
155
+ public:
156
+ MultiVectorSearchStopCondition(
157
+ BaseMultiVectorSpace<DOCIDTYPE>& space,
158
+ size_t num_docs_to_search,
159
+ size_t ef_collection = 10)
160
+ : space_(space) {
161
+ curr_num_docs_ = 0;
162
+ num_docs_to_search_ = num_docs_to_search;
163
+ ef_collection_ = std::max(ef_collection, num_docs_to_search);
164
+ }
165
+
166
+ void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
167
+ DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
168
+ if (doc_counter_[doc_id] == 0) {
169
+ curr_num_docs_ += 1;
170
+ }
171
+ search_results_.emplace(dist, doc_id);
172
+ doc_counter_[doc_id] += 1;
173
+ }
174
+
175
+ void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
176
+ DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
177
+ doc_counter_[doc_id] -= 1;
178
+ if (doc_counter_[doc_id] == 0) {
179
+ curr_num_docs_ -= 1;
180
+ }
181
+ search_results_.pop();
182
+ }
183
+
184
+ bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
185
+ bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
186
+ return stop_search;
187
+ }
188
+
189
+ bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
190
+ bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
191
+ return flag_consider_candidate;
192
+ }
193
+
194
+ bool should_remove_extra() override {
195
+ bool flag_remove_extra = curr_num_docs_ > ef_collection_;
196
+ return flag_remove_extra;
197
+ }
198
+
199
+ void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
200
+ while (curr_num_docs_ > num_docs_to_search_) {
201
+ dist_t dist_cand = candidates.back().first;
202
+ dist_t dist_res = search_results_.top().first;
203
+ assert(dist_cand == dist_res);
204
+ DOCIDTYPE doc_id = search_results_.top().second;
205
+ doc_counter_[doc_id] -= 1;
206
+ if (doc_counter_[doc_id] == 0) {
207
+ curr_num_docs_ -= 1;
208
+ }
209
+ search_results_.pop();
210
+ candidates.pop_back();
211
+ }
212
+ }
213
+
214
+ ~MultiVectorSearchStopCondition() {}
215
+ };
216
+
217
+
218
+ template<typename dist_t>
219
+ class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
220
+ float epsilon_;
221
+ size_t min_num_candidates_;
222
+ size_t max_num_candidates_;
223
+ size_t curr_num_items_;
224
+
225
+ public:
226
+ EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
227
+ assert(min_num_candidates <= max_num_candidates);
228
+ epsilon_ = epsilon;
229
+ min_num_candidates_ = min_num_candidates;
230
+ max_num_candidates_ = max_num_candidates;
231
+ curr_num_items_ = 0;
232
+ }
233
+
234
+ void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
235
+ curr_num_items_ += 1;
236
+ }
237
+
238
+ void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
239
+ curr_num_items_ -= 1;
240
+ }
241
+
242
+ bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
243
+ if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
244
+ // new candidate can't improve found results
245
+ return true;
246
+ }
247
+ if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
248
+ // new candidate is out of epsilon region and
249
+ // minimum number of candidates is checked
250
+ return true;
251
+ }
252
+ return false;
253
+ }
254
+
255
+ bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
256
+ bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
257
+ return flag_consider_candidate;
258
+ }
259
+
260
+ bool should_remove_extra() {
261
+ bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
262
+ return flag_remove_extra;
263
+ }
264
+
265
+ void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
266
+ while (!candidates.empty() && candidates.back().first > epsilon_) {
267
+ candidates.pop_back();
268
+ }
269
+ while (candidates.size() > max_num_candidates_) {
270
+ candidates.pop_back();
271
+ }
272
+ }
273
+
274
+ ~EpsilonSearchStopCondition() {}
275
+ };
276
+ } // namespace hnswlib
@@ -3,8 +3,8 @@
3
3
  # Hnswlib.rb provides Ruby bindings for the Hnswlib.
4
4
  module Hnswlib
5
5
  # The version of Hnswlib.rb you install.
6
- VERSION = '0.8.1'
6
+ VERSION = '0.9.0'
7
7
 
8
8
  # The version of Hnswlib included with gem.
9
- HSWLIB_VERSION = '0.7.0'
9
+ HSWLIB_VERSION = '0.8.0'
10
10
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: hnswlib
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.1
4
+ version: 0.9.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-18 00:00:00.000000000 Z
11
+ date: 2023-12-16 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
14
14
  email:
@@ -30,6 +30,7 @@ files:
30
30
  - ext/hnswlib/src/hnswlib.h
31
31
  - ext/hnswlib/src/space_ip.h
32
32
  - ext/hnswlib/src/space_l2.h
33
+ - ext/hnswlib/src/stop_condition.h
33
34
  - ext/hnswlib/src/visited_list_pool.h
34
35
  - lib/hnswlib.rb
35
36
  - lib/hnswlib/version.rb
@@ -41,6 +42,7 @@ metadata:
41
42
  homepage_uri: https://github.com/yoshoku/hnswlib.rb
42
43
  source_code_uri: https://github.com/yoshoku/hnswlib.rb
43
44
  changelog_uri: https://github.com/yoshoku/hnswlib.rb/blob/main/CHANGELOG.md
45
+ documentation_uri: https://yoshoku.github.io/hnswlib.rb/doc/
44
46
  rubygems_mfa_required: 'true'
45
47
  post_install_message:
46
48
  rdoc_options: []
@@ -57,7 +59,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
57
59
  - !ruby/object:Gem::Version
58
60
  version: '0'
59
61
  requirements: []
60
- rubygems_version: 3.3.26
62
+ rubygems_version: 3.4.22
61
63
  signing_key:
62
64
  specification_version: 4
63
65
  summary: Ruby bindings for the Hnswlib.