hnswlib 0.8.1 → 0.9.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33dd2f9cd8656dfe469b4b9baf319ae4e0b78826dec58da8ae78516077a6fbc2
4
- data.tar.gz: ad32e54a325136587ae12716243b88ed1269c7bc9cd4e92e9aa8896b2517d4d2
3
+ metadata.gz: 39458b40736b4a330a0c1769a3adcf37d9d40614b1270d33885932244e943b74
4
+ data.tar.gz: d49bf8e158c55235fdeca0656cd3ddab6a07e5abc1468eb8917b21f9baf5d89c
5
5
  SHA512:
6
- metadata.gz: 361822ebf216d8f8ba3abf8314e52b1a72887fb2a03aea29572940376326b128442fc007fa7de6a6e925facb0b9295076fb10a98ff053aad7e0fbe41e4973c2a
7
- data.tar.gz: 38f9f4b10665d9e87940b23fae51a15640229367f6bb74a7dba176bab6e798c8fad8abce106c07e38e0acc9046949f80208f9f9b30e3341aceb5a6193dac3aca
6
+ metadata.gz: ec8560f2b7aa9bb5df098b9c863e4602465a42e8ea27a19be55b9d62e9ba0d32fa70ccbcb13323b771370df5ded9370f9f91d17c8d63aecc18d2b45e29df0149
7
+ data.tar.gz: c9b434fb313f41e2f0570070944eb6eb3b4432a65fce2c63ca6592ab1de30f9b04834d87e1bd40db465dceea0eaac2279378eda537f6b7d51c2bdb0aa06e1ab6
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## [[0.9.0](https://github.com/yoshoku/hnswlib.rb/compare/v0.8.1...v0.9.0)] - 2023-12-16
2
+
3
+ - Update bundled hnswlib version to 0.8.0.
4
+ - Multi-vector document search and epsilon search, which were added only to the C++ version, are not supported. These features will be supported in a future release.
5
+
1
6
  ## [0.8.1] - 2023-03-18
2
7
 
3
8
  - Update the type declarations of HierarchicalNSW and BruteforceSearch along with recent changes.
@@ -528,10 +528,6 @@ private:
528
528
  free(index->linkLists_);
529
529
  index->linkLists_ = nullptr;
530
530
  }
531
- if (index->visited_list_pool_) {
532
- delete index->visited_list_pool_;
533
- index->visited_list_pool_ = nullptr;
534
- }
535
531
 
536
532
  try {
537
533
  index->loadIndex(filename, space);
@@ -85,10 +85,16 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
85
85
 
86
86
 
87
87
  void removePoint(labeltype cur_external) {
88
- size_t cur_c = dict_external_to_internal[cur_external];
88
+ std::unique_lock<std::mutex> lock(index_lock);
89
89
 
90
- dict_external_to_internal.erase(cur_external);
90
+ auto found = dict_external_to_internal.find(cur_external);
91
+ if (found == dict_external_to_internal.end()) {
92
+ return;
93
+ }
94
+
95
+ dict_external_to_internal.erase(found);
91
96
 
97
+ size_t cur_c = found->second;
92
98
  labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
93
99
  dict_external_to_internal[label] = cur_c;
94
100
  memcpy(data_ + size_per_element_ * cur_c,
@@ -107,7 +113,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
107
113
  dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
108
114
  labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
109
115
  if ((!isIdAllowed) || (*isIdAllowed)(label)) {
110
- topResults.push(std::pair<dist_t, labeltype>(dist, label));
116
+ topResults.emplace(dist, label);
111
117
  }
112
118
  }
113
119
  dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
@@ -116,7 +122,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
116
122
  if (dist <= lastdist) {
117
123
  labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
118
124
  if ((!isIdAllowed) || (*isIdAllowed)(label)) {
119
- topResults.push(std::pair<dist_t, labeltype>(dist, label));
125
+ topResults.emplace(dist, label);
120
126
  }
121
127
  if (topResults.size() > k)
122
128
  topResults.pop();
@@ -8,6 +8,7 @@
8
8
  #include <assert.h>
9
9
  #include <unordered_set>
10
10
  #include <list>
11
+ #include <memory>
11
12
 
12
13
  namespace hnswlib {
13
14
  typedef unsigned int tableint;
@@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
33
34
  double mult_{0.0}, revSize_{0.0};
34
35
  int maxlevel_{0};
35
36
 
36
- VisitedListPool *visited_list_pool_{nullptr};
37
+ std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};
37
38
 
38
39
  // Locks operations with element by label value
39
40
  mutable std::vector<std::mutex> label_op_locks_;
@@ -93,8 +94,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
93
94
  size_t ef_construction = 200,
94
95
  size_t random_seed = 100,
95
96
  bool allow_replace_deleted = false)
96
- : link_list_locks_(max_elements),
97
- label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
97
+ : label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
98
+ link_list_locks_(max_elements),
98
99
  element_levels_(max_elements),
99
100
  allow_replace_deleted_(allow_replace_deleted) {
100
101
  max_elements_ = max_elements;
@@ -102,7 +103,13 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
102
103
  data_size_ = s->get_data_size();
103
104
  fstdistfunc_ = s->get_dist_func();
104
105
  dist_func_param_ = s->get_dist_func_param();
105
- M_ = M;
106
+ if ( M <= 10000 ) {
107
+ M_ = M;
108
+ } else {
109
+ HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
110
+ HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
111
+ M_ = 10000;
112
+ }
106
113
  maxM_ = M_;
107
114
  maxM0_ = M_ * 2;
108
115
  ef_construction_ = std::max(ef_construction, M_);
@@ -123,7 +130,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
123
130
 
124
131
  cur_element_count = 0;
125
132
 
126
- visited_list_pool_ = new VisitedListPool(1, max_elements);
133
+ visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));
127
134
 
128
135
  // initializations for special treatment of the first node
129
136
  enterpoint_node_ = -1;
@@ -139,13 +146,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
139
146
 
140
147
 
141
148
  ~HierarchicalNSW() {
149
+ clear();
150
+ }
151
+
152
+ void clear() {
142
153
  free(data_level0_memory_);
154
+ data_level0_memory_ = nullptr;
143
155
  for (tableint i = 0; i < cur_element_count; i++) {
144
156
  if (element_levels_[i] > 0)
145
157
  free(linkLists_[i]);
146
158
  }
147
159
  free(linkLists_);
148
- delete visited_list_pool_;
160
+ linkLists_ = nullptr;
161
+ cur_element_count = 0;
162
+ visited_list_pool_.reset(nullptr);
149
163
  }
150
164
 
151
165
 
@@ -292,9 +306,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
292
306
  }
293
307
 
294
308
 
295
- template <bool has_deletions, bool collect_metrics = false>
309
+ // bare_bone_search skips the deleted-element and id-filter checks and ignores the stop condition in exchange for extra performance
310
+ template <bool bare_bone_search = true, bool collect_metrics = false>
296
311
  std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
297
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
312
+ searchBaseLayerST(
313
+ tableint ep_id,
314
+ const void *data_point,
315
+ size_t ef,
316
+ BaseFilterFunctor* isIdAllowed = nullptr,
317
+ BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
298
318
  VisitedList *vl = visited_list_pool_->getFreeVisitedList();
299
319
  vl_type *visited_array = vl->mass;
300
320
  vl_type visited_array_tag = vl->curV;
@@ -303,10 +323,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
303
323
  std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
304
324
 
305
325
  dist_t lowerBound;
306
- if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
307
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
326
+ if (bare_bone_search ||
327
+ (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
328
+ char* ep_data = getDataByInternalId(ep_id);
329
+ dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
308
330
  lowerBound = dist;
309
331
  top_candidates.emplace(dist, ep_id);
332
+ if (!bare_bone_search && stop_condition) {
333
+ stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
334
+ }
310
335
  candidate_set.emplace(-dist, ep_id);
311
336
  } else {
312
337
  lowerBound = std::numeric_limits<dist_t>::max();
@@ -317,9 +342,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
317
342
 
318
343
  while (!candidate_set.empty()) {
319
344
  std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
345
+ dist_t candidate_dist = -current_node_pair.first;
320
346
 
321
- if ((-current_node_pair.first) > lowerBound &&
322
- (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
347
+ bool flag_stop_search;
348
+ if (bare_bone_search) {
349
+ flag_stop_search = candidate_dist > lowerBound;
350
+ } else {
351
+ if (stop_condition) {
352
+ flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
353
+ } else {
354
+ flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
355
+ }
356
+ }
357
+ if (flag_stop_search) {
323
358
  break;
324
359
  }
325
360
  candidate_set.pop();
@@ -354,7 +389,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
354
389
  char *currObj1 = (getDataByInternalId(candidate_id));
355
390
  dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
356
391
 
357
- if (top_candidates.size() < ef || lowerBound > dist) {
392
+ bool flag_consider_candidate;
393
+ if (!bare_bone_search && stop_condition) {
394
+ flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
395
+ } else {
396
+ flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
397
+ }
398
+
399
+ if (flag_consider_candidate) {
358
400
  candidate_set.emplace(-dist, candidate_id);
359
401
  #ifdef USE_SSE
360
402
  _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
@@ -362,11 +404,30 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
362
404
  _MM_HINT_T0); ////////////////////////
363
405
  #endif
364
406
 
365
- if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
407
+ if (bare_bone_search ||
408
+ (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
366
409
  top_candidates.emplace(dist, candidate_id);
410
+ if (!bare_bone_search && stop_condition) {
411
+ stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
412
+ }
413
+ }
367
414
 
368
- if (top_candidates.size() > ef)
415
+ bool flag_remove_extra = false;
416
+ if (!bare_bone_search && stop_condition) {
417
+ flag_remove_extra = stop_condition->should_remove_extra();
418
+ } else {
419
+ flag_remove_extra = top_candidates.size() > ef;
420
+ }
421
+ while (flag_remove_extra) {
422
+ tableint id = top_candidates.top().second;
369
423
  top_candidates.pop();
424
+ if (!bare_bone_search && stop_condition) {
425
+ stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
426
+ flag_remove_extra = stop_condition->should_remove_extra();
427
+ } else {
428
+ flag_remove_extra = top_candidates.size() > ef;
429
+ }
430
+ }
370
431
 
371
432
  if (!top_candidates.empty())
372
433
  lowerBound = top_candidates.top().first;
@@ -381,8 +442,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
381
442
 
382
443
 
383
444
  void getNeighborsByHeuristic2(
384
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
385
- const size_t M) {
445
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
446
+ const size_t M) {
386
447
  if (top_candidates.size() < M) {
387
448
  return;
388
449
  }
@@ -574,8 +635,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
574
635
  if (new_max_elements < cur_element_count)
575
636
  throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
576
637
 
577
- delete visited_list_pool_;
578
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
638
+ visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));
579
639
 
580
640
  element_levels_.resize(new_max_elements);
581
641
 
@@ -596,6 +656,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
596
656
  max_elements_ = new_max_elements;
597
657
  }
598
658
 
659
+ size_t indexFileSize() const {
660
+ size_t size = 0;
661
+ size += sizeof(offsetLevel0_);
662
+ size += sizeof(max_elements_);
663
+ size += sizeof(cur_element_count);
664
+ size += sizeof(size_data_per_element_);
665
+ size += sizeof(label_offset_);
666
+ size += sizeof(offsetData_);
667
+ size += sizeof(maxlevel_);
668
+ size += sizeof(enterpoint_node_);
669
+ size += sizeof(maxM_);
670
+
671
+ size += sizeof(maxM0_);
672
+ size += sizeof(M_);
673
+ size += sizeof(mult_);
674
+ size += sizeof(ef_construction_);
675
+
676
+ size += cur_element_count * size_data_per_element_;
677
+
678
+ for (size_t i = 0; i < cur_element_count; i++) {
679
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
680
+ size += sizeof(linkListSize);
681
+ size += linkListSize;
682
+ }
683
+ return size;
684
+ }
599
685
 
600
686
  void saveIndex(const std::string &location) {
601
687
  std::ofstream output(location, std::ios::binary);
@@ -634,6 +720,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
634
720
  if (!input.is_open())
635
721
  throw std::runtime_error("Cannot open file");
636
722
 
723
+ clear();
637
724
  // get file size:
638
725
  input.seekg(0, input.end);
639
726
  std::streampos total_filesize = input.tellg();
@@ -699,7 +786,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
699
786
  std::vector<std::mutex>(max_elements).swap(link_list_locks_);
700
787
  std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
701
788
 
702
- visited_list_pool_ = new VisitedListPool(1, max_elements);
789
+ visited_list_pool_.reset(new VisitedListPool(1, max_elements));
703
790
 
704
791
  linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
705
792
  if (linkLists_ == nullptr)
@@ -753,7 +840,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
753
840
  size_t dim = *((size_t *) dist_func_param_);
754
841
  std::vector<data_t> data;
755
842
  data_t* data_ptr = (data_t*) data_ptrv;
756
- for (int i = 0; i < dim; i++) {
843
+ for (size_t i = 0; i < dim; i++) {
757
844
  data.push_back(*data_ptr);
758
845
  data_ptr += 1;
759
846
  }
@@ -1217,11 +1304,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1217
1304
  }
1218
1305
 
1219
1306
  std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1220
- if (num_deleted_) {
1221
- top_candidates = searchBaseLayerST<true, true>(
1307
+ bool bare_bone_search = !num_deleted_ && !isIdAllowed;
1308
+ if (bare_bone_search) {
1309
+ top_candidates = searchBaseLayerST<true>(
1222
1310
  currObj, query_data, std::max(ef_, k), isIdAllowed);
1223
1311
  } else {
1224
- top_candidates = searchBaseLayerST<false, true>(
1312
+ top_candidates = searchBaseLayerST<false>(
1225
1313
  currObj, query_data, std::max(ef_, k), isIdAllowed);
1226
1314
  }
1227
1315
 
@@ -1237,6 +1325,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1237
1325
  }
1238
1326
 
1239
1327
 
1328
+ std::vector<std::pair<dist_t, labeltype >>
1329
+ searchStopConditionClosest(
1330
+ const void *query_data,
1331
+ BaseSearchStopCondition<dist_t>& stop_condition,
1332
+ BaseFilterFunctor* isIdAllowed = nullptr) const {
1333
+ std::vector<std::pair<dist_t, labeltype >> result;
1334
+ if (cur_element_count == 0) return result;
1335
+
1336
+ tableint currObj = enterpoint_node_;
1337
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1338
+
1339
+ for (int level = maxlevel_; level > 0; level--) {
1340
+ bool changed = true;
1341
+ while (changed) {
1342
+ changed = false;
1343
+ unsigned int *data;
1344
+
1345
+ data = (unsigned int *) get_linklist(currObj, level);
1346
+ int size = getListCount(data);
1347
+ metric_hops++;
1348
+ metric_distance_computations+=size;
1349
+
1350
+ tableint *datal = (tableint *) (data + 1);
1351
+ for (int i = 0; i < size; i++) {
1352
+ tableint cand = datal[i];
1353
+ if (cand < 0 || cand > max_elements_)
1354
+ throw std::runtime_error("cand error");
1355
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1356
+
1357
+ if (d < curdist) {
1358
+ curdist = d;
1359
+ currObj = cand;
1360
+ changed = true;
1361
+ }
1362
+ }
1363
+ }
1364
+ }
1365
+
1366
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1367
+ top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);
1368
+
1369
+ size_t sz = top_candidates.size();
1370
+ result.resize(sz);
1371
+ while (!top_candidates.empty()) {
1372
+ result[--sz] = top_candidates.top();
1373
+ top_candidates.pop();
1374
+ }
1375
+
1376
+ stop_condition.filter_results(result);
1377
+
1378
+ return result;
1379
+ }
1380
+
1381
+
1240
1382
  void checkIntegrity() {
1241
1383
  int connections_checked = 0;
1242
1384
  std::vector <int > inbound_connections_num(cur_element_count, 0);
@@ -1247,7 +1389,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1247
1389
  tableint *data = (tableint *) (ll_cur + 1);
1248
1390
  std::unordered_set<tableint> s;
1249
1391
  for (int j = 0; j < size; j++) {
1250
- assert(data[j] > 0);
1251
1392
  assert(data[j] < cur_element_count);
1252
1393
  assert(data[j] != i);
1253
1394
  inbound_connections_num[data[j]]++;
@@ -1,4 +1,13 @@
1
1
  #pragma once
2
+
3
+ // https://github.com/nmslib/hnswlib/pull/508
4
+ // This allows others to provide their own error stream (e.g. RcppHNSW)
5
+ #ifndef HNSWLIB_ERR_OVERRIDE
6
+ #define HNSWERR std::cerr
7
+ #else
8
+ #define HNSWERR HNSWLIB_ERR_OVERRIDE
9
+ #endif
10
+
2
11
  #ifndef NO_MANUAL_VECTORIZATION
3
12
  #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
4
13
  #define USE_SSE
@@ -15,7 +24,7 @@
15
24
  #ifdef _MSC_VER
16
25
  #include <intrin.h>
17
26
  #include <stdexcept>
18
- void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
27
+ static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
19
28
  __cpuidex(out, eax, ecx);
20
29
  }
21
30
  static __int64 xgetbv(unsigned int x) {
@@ -122,6 +131,24 @@ class BaseFilterFunctor {
122
131
  virtual ~BaseFilterFunctor() {};
123
132
  };
124
133
 
134
+ template<typename dist_t>
135
+ class BaseSearchStopCondition {
136
+ public:
137
+ virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;
138
+
139
+ virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;
140
+
141
+ virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;
142
+
143
+ virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;
144
+
145
+ virtual bool should_remove_extra() = 0;
146
+
147
+ virtual void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) = 0;
148
+
149
+ virtual ~BaseSearchStopCondition() {}
150
+ };
151
+
125
152
  template <typename T>
126
153
  class pairGreater {
127
154
  public:
@@ -196,5 +223,6 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
196
223
 
197
224
  #include "space_l2.h"
198
225
  #include "space_ip.h"
226
+ #include "stop_condition.h"
199
227
  #include "bruteforce.h"
200
228
  #include "hnswalg.h"
@@ -157,19 +157,44 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void
157
157
 
158
158
  __m512 sum512 = _mm512_set1_ps(0);
159
159
 
160
- while (pVect1 < pEnd1) {
161
- //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
160
+ size_t loop = qty16 / 4;
162
161
 
162
+ while (loop--) {
163
163
  __m512 v1 = _mm512_loadu_ps(pVect1);
164
- pVect1 += 16;
165
164
  __m512 v2 = _mm512_loadu_ps(pVect2);
165
+ pVect1 += 16;
166
+ pVect2 += 16;
167
+
168
+ __m512 v3 = _mm512_loadu_ps(pVect1);
169
+ __m512 v4 = _mm512_loadu_ps(pVect2);
170
+ pVect1 += 16;
171
+ pVect2 += 16;
172
+
173
+ __m512 v5 = _mm512_loadu_ps(pVect1);
174
+ __m512 v6 = _mm512_loadu_ps(pVect2);
175
+ pVect1 += 16;
166
176
  pVect2 += 16;
167
- sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
177
+
178
+ __m512 v7 = _mm512_loadu_ps(pVect1);
179
+ __m512 v8 = _mm512_loadu_ps(pVect2);
180
+ pVect1 += 16;
181
+ pVect2 += 16;
182
+
183
+ sum512 = _mm512_fmadd_ps(v1, v2, sum512);
184
+ sum512 = _mm512_fmadd_ps(v3, v4, sum512);
185
+ sum512 = _mm512_fmadd_ps(v5, v6, sum512);
186
+ sum512 = _mm512_fmadd_ps(v7, v8, sum512);
168
187
  }
169
188
 
170
- _mm512_store_ps(TmpRes, sum512);
171
- float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
189
+ while (pVect1 < pEnd1) {
190
+ __m512 v1 = _mm512_loadu_ps(pVect1);
191
+ __m512 v2 = _mm512_loadu_ps(pVect2);
192
+ pVect1 += 16;
193
+ pVect2 += 16;
194
+ sum512 = _mm512_fmadd_ps(v1, v2, sum512);
195
+ }
172
196
 
197
+ float sum = _mm512_reduce_add_ps(sum512);
173
198
  return sum;
174
199
  }
175
200
 
@@ -0,0 +1,276 @@
1
+ #pragma once
2
+ #include "space_l2.h"
3
+ #include "space_ip.h"
4
+ #include <assert.h>
5
+ #include <unordered_map>
6
+
7
+ namespace hnswlib {
8
+
9
+ template<typename DOCIDTYPE>
10
+ class BaseMultiVectorSpace : public SpaceInterface<float> {
11
+ public:
12
+ virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;
13
+
14
+ virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
15
+ };
16
+
17
+
18
+ template<typename DOCIDTYPE>
19
+ class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
20
+ DISTFUNC<float> fstdistfunc_;
21
+ size_t data_size_;
22
+ size_t vector_size_;
23
+ size_t dim_;
24
+
25
+ public:
26
+ MultiVectorL2Space(size_t dim) {
27
+ fstdistfunc_ = L2Sqr;
28
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
29
+ #if defined(USE_AVX512)
30
+ if (AVX512Capable())
31
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
32
+ else if (AVXCapable())
33
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
34
+ #elif defined(USE_AVX)
35
+ if (AVXCapable())
36
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
37
+ #endif
38
+
39
+ if (dim % 16 == 0)
40
+ fstdistfunc_ = L2SqrSIMD16Ext;
41
+ else if (dim % 4 == 0)
42
+ fstdistfunc_ = L2SqrSIMD4Ext;
43
+ else if (dim > 16)
44
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
45
+ else if (dim > 4)
46
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
47
+ #endif
48
+ dim_ = dim;
49
+ vector_size_ = dim * sizeof(float);
50
+ data_size_ = vector_size_ + sizeof(DOCIDTYPE);
51
+ }
52
+
53
+ size_t get_data_size() override {
54
+ return data_size_;
55
+ }
56
+
57
+ DISTFUNC<float> get_dist_func() override {
58
+ return fstdistfunc_;
59
+ }
60
+
61
+ void *get_dist_func_param() override {
62
+ return &dim_;
63
+ }
64
+
65
+ DOCIDTYPE get_doc_id(const void *datapoint) override {
66
+ return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
67
+ }
68
+
69
+ void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
70
+ *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
71
+ }
72
+
73
+ ~MultiVectorL2Space() {}
74
+ };
75
+
76
+
77
+ template<typename DOCIDTYPE>
78
+ class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
79
+ DISTFUNC<float> fstdistfunc_;
80
+ size_t data_size_;
81
+ size_t vector_size_;
82
+ size_t dim_;
83
+
84
+ public:
85
+ MultiVectorInnerProductSpace(size_t dim) {
86
+ fstdistfunc_ = InnerProductDistance;
87
+ #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
88
+ #if defined(USE_AVX512)
89
+ if (AVX512Capable()) {
90
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
91
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
92
+ } else if (AVXCapable()) {
93
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
94
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
95
+ }
96
+ #elif defined(USE_AVX)
97
+ if (AVXCapable()) {
98
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
99
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
100
+ }
101
+ #endif
102
+ #if defined(USE_AVX)
103
+ if (AVXCapable()) {
104
+ InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
105
+ InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
106
+ }
107
+ #endif
108
+
109
+ if (dim % 16 == 0)
110
+ fstdistfunc_ = InnerProductDistanceSIMD16Ext;
111
+ else if (dim % 4 == 0)
112
+ fstdistfunc_ = InnerProductDistanceSIMD4Ext;
113
+ else if (dim > 16)
114
+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
115
+ else if (dim > 4)
116
+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
117
+ #endif
118
+ vector_size_ = dim * sizeof(float);
119
+ data_size_ = vector_size_ + sizeof(DOCIDTYPE);
120
+ }
121
+
122
+ size_t get_data_size() override {
123
+ return data_size_;
124
+ }
125
+
126
+ DISTFUNC<float> get_dist_func() override {
127
+ return fstdistfunc_;
128
+ }
129
+
130
+ void *get_dist_func_param() override {
131
+ return &dim_;
132
+ }
133
+
134
+ DOCIDTYPE get_doc_id(const void *datapoint) override {
135
+ return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
136
+ }
137
+
138
+ void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
139
+ *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
140
+ }
141
+
142
+ ~MultiVectorInnerProductSpace() {}
143
+ };
144
+
145
+
146
+ template<typename DOCIDTYPE, typename dist_t>
147
+ class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
148
+ size_t curr_num_docs_;
149
+ size_t num_docs_to_search_;
150
+ size_t ef_collection_;
151
+ std::unordered_map<DOCIDTYPE, size_t> doc_counter_;
152
+ std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;
153
+ BaseMultiVectorSpace<DOCIDTYPE>& space_;
154
+
155
+ public:
156
+ MultiVectorSearchStopCondition(
157
+ BaseMultiVectorSpace<DOCIDTYPE>& space,
158
+ size_t num_docs_to_search,
159
+ size_t ef_collection = 10)
160
+ : space_(space) {
161
+ curr_num_docs_ = 0;
162
+ num_docs_to_search_ = num_docs_to_search;
163
+ ef_collection_ = std::max(ef_collection, num_docs_to_search);
164
+ }
165
+
166
+ void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
167
+ DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
168
+ if (doc_counter_[doc_id] == 0) {
169
+ curr_num_docs_ += 1;
170
+ }
171
+ search_results_.emplace(dist, doc_id);
172
+ doc_counter_[doc_id] += 1;
173
+ }
174
+
175
+ void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
176
+ DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
177
+ doc_counter_[doc_id] -= 1;
178
+ if (doc_counter_[doc_id] == 0) {
179
+ curr_num_docs_ -= 1;
180
+ }
181
+ search_results_.pop();
182
+ }
183
+
184
+ bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
185
+ bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
186
+ return stop_search;
187
+ }
188
+
189
+ bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
190
+ bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
191
+ return flag_consider_candidate;
192
+ }
193
+
194
+ bool should_remove_extra() override {
195
+ bool flag_remove_extra = curr_num_docs_ > ef_collection_;
196
+ return flag_remove_extra;
197
+ }
198
+
199
+ void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
200
+ while (curr_num_docs_ > num_docs_to_search_) {
201
+ dist_t dist_cand = candidates.back().first;
202
+ dist_t dist_res = search_results_.top().first;
203
+ assert(dist_cand == dist_res);
204
+ DOCIDTYPE doc_id = search_results_.top().second;
205
+ doc_counter_[doc_id] -= 1;
206
+ if (doc_counter_[doc_id] == 0) {
207
+ curr_num_docs_ -= 1;
208
+ }
209
+ search_results_.pop();
210
+ candidates.pop_back();
211
+ }
212
+ }
213
+
214
+ ~MultiVectorSearchStopCondition() {}
215
+ };
216
+
217
+
218
+ template<typename dist_t>
219
+ class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
220
+ float epsilon_;
221
+ size_t min_num_candidates_;
222
+ size_t max_num_candidates_;
223
+ size_t curr_num_items_;
224
+
225
+ public:
226
+ EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
227
+ assert(min_num_candidates <= max_num_candidates);
228
+ epsilon_ = epsilon;
229
+ min_num_candidates_ = min_num_candidates;
230
+ max_num_candidates_ = max_num_candidates;
231
+ curr_num_items_ = 0;
232
+ }
233
+
234
+ void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
235
+ curr_num_items_ += 1;
236
+ }
237
+
238
+ void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
239
+ curr_num_items_ -= 1;
240
+ }
241
+
242
+ bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
243
+ if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
244
+ // new candidate can't improve found results
245
+ return true;
246
+ }
247
+ if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
248
+ // new candidate is out of epsilon region and
249
+ // minimum number of candidates is checked
250
+ return true;
251
+ }
252
+ return false;
253
+ }
254
+
255
+ bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
256
+ bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
257
+ return flag_consider_candidate;
258
+ }
259
+
260
+ bool should_remove_extra() {
261
+ bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
262
+ return flag_remove_extra;
263
+ }
264
+
265
+ void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
266
+ while (!candidates.empty() && candidates.back().first > epsilon_) {
267
+ candidates.pop_back();
268
+ }
269
+ while (candidates.size() > max_num_candidates_) {
270
+ candidates.pop_back();
271
+ }
272
+ }
273
+
274
+ ~EpsilonSearchStopCondition() {}
275
+ };
276
+ } // namespace hnswlib
@@ -3,8 +3,8 @@
3
3
  # Hnswlib.rb provides Ruby bindings for the Hnswlib.
4
4
  module Hnswlib
5
5
  # The version of Hnswlib.rb you install.
6
- VERSION = '0.8.1'
6
+ VERSION = '0.9.0'
7
7
 
8
8
  # The version of Hnswlib included with gem.
9
- HSWLIB_VERSION = '0.7.0'
9
+ HSWLIB_VERSION = '0.8.0'
10
10
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: hnswlib
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.1
4
+ version: 0.9.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-03-18 00:00:00.000000000 Z
11
+ date: 2023-12-16 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
14
14
  email:
@@ -30,6 +30,7 @@ files:
30
30
  - ext/hnswlib/src/hnswlib.h
31
31
  - ext/hnswlib/src/space_ip.h
32
32
  - ext/hnswlib/src/space_l2.h
33
+ - ext/hnswlib/src/stop_condition.h
33
34
  - ext/hnswlib/src/visited_list_pool.h
34
35
  - lib/hnswlib.rb
35
36
  - lib/hnswlib/version.rb
@@ -41,6 +42,7 @@ metadata:
41
42
  homepage_uri: https://github.com/yoshoku/hnswlib.rb
42
43
  source_code_uri: https://github.com/yoshoku/hnswlib.rb
43
44
  changelog_uri: https://github.com/yoshoku/hnswlib.rb/blob/main/CHANGELOG.md
45
+ documentation_uri: https://yoshoku.github.io/hnswlib.rb/doc/
44
46
  rubygems_mfa_required: 'true'
45
47
  post_install_message:
46
48
  rdoc_options: []
@@ -57,7 +59,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
57
59
  - !ruby/object:Gem::Version
58
60
  version: '0'
59
61
  requirements: []
60
- rubygems_version: 3.3.26
62
+ rubygems_version: 3.4.22
61
63
  signing_key:
62
64
  specification_version: 4
63
65
  summary: Ruby bindings for the Hnswlib.