hnswlib 0.8.1 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/hnswlib/hnswlibext.hpp +0 -4
- data/ext/hnswlib/src/bruteforce.h +10 -4
- data/ext/hnswlib/src/hnswalg.h +166 -25
- data/ext/hnswlib/src/hnswlib.h +29 -1
- data/ext/hnswlib/src/space_ip.h +31 -6
- data/ext/hnswlib/src/stop_condition.h +276 -0
- data/lib/hnswlib/version.rb +2 -2
- metadata +5 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 39458b40736b4a330a0c1769a3adcf37d9d40614b1270d33885932244e943b74
|
4
|
+
data.tar.gz: d49bf8e158c55235fdeca0656cd3ddab6a07e5abc1468eb8917b21f9baf5d89c
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ec8560f2b7aa9bb5df098b9c863e4602465a42e8ea27a19be55b9d62e9ba0d32fa70ccbcb13323b771370df5ded9370f9f91d17c8d63aecc18d2b45e29df0149
|
7
|
+
data.tar.gz: c9b434fb313f41e2f0570070944eb6eb3b4432a65fce2c63ca6592ab1de30f9b04834d87e1bd40db465dceea0eaac2279378eda537f6b7d51c2bdb0aa06e1ab6
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,8 @@
|
|
1
|
+
## [[0.9.0](https://github.com/yoshoku/hnswlib.rb/compare/v0.8.1...v0.9.0)] - 2023-12-16
|
2
|
+
|
3
|
+
- Update bundled hnswlib version to 0.8.0.
|
4
|
+
- Multi-vector document search and epsilon search, which were added only to the C++ version, are not supported. These features will be supported in a future release.
|
5
|
+
|
1
6
|
## [0.8.1] - 2023-03-18
|
2
7
|
|
3
8
|
- Update the type declarations of HierarchicalNSW and BruteforceSearch along with recent changes.
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -528,10 +528,6 @@ private:
|
|
528
528
|
free(index->linkLists_);
|
529
529
|
index->linkLists_ = nullptr;
|
530
530
|
}
|
531
|
-
if (index->visited_list_pool_) {
|
532
|
-
delete index->visited_list_pool_;
|
533
|
-
index->visited_list_pool_ = nullptr;
|
534
|
-
}
|
535
531
|
|
536
532
|
try {
|
537
533
|
index->loadIndex(filename, space);
|
@@ -85,10 +85,16 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
|
85
85
|
|
86
86
|
|
87
87
|
void removePoint(labeltype cur_external) {
|
88
|
-
|
88
|
+
std::unique_lock<std::mutex> lock(index_lock);
|
89
89
|
|
90
|
-
dict_external_to_internal.
|
90
|
+
auto found = dict_external_to_internal.find(cur_external);
|
91
|
+
if (found == dict_external_to_internal.end()) {
|
92
|
+
return;
|
93
|
+
}
|
94
|
+
|
95
|
+
dict_external_to_internal.erase(found);
|
91
96
|
|
97
|
+
size_t cur_c = found->second;
|
92
98
|
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
93
99
|
dict_external_to_internal[label] = cur_c;
|
94
100
|
memcpy(data_ + size_per_element_ * cur_c,
|
@@ -107,7 +113,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
|
107
113
|
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
108
114
|
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
|
109
115
|
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
110
|
-
topResults.
|
116
|
+
topResults.emplace(dist, label);
|
111
117
|
}
|
112
118
|
}
|
113
119
|
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
|
@@ -116,7 +122,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
|
116
122
|
if (dist <= lastdist) {
|
117
123
|
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
|
118
124
|
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
119
|
-
topResults.
|
125
|
+
topResults.emplace(dist, label);
|
120
126
|
}
|
121
127
|
if (topResults.size() > k)
|
122
128
|
topResults.pop();
|
data/ext/hnswlib/src/hnswalg.h
CHANGED
@@ -8,6 +8,7 @@
|
|
8
8
|
#include <assert.h>
|
9
9
|
#include <unordered_set>
|
10
10
|
#include <list>
|
11
|
+
#include <memory>
|
11
12
|
|
12
13
|
namespace hnswlib {
|
13
14
|
typedef unsigned int tableint;
|
@@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
33
34
|
double mult_{0.0}, revSize_{0.0};
|
34
35
|
int maxlevel_{0};
|
35
36
|
|
36
|
-
VisitedListPool
|
37
|
+
std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};
|
37
38
|
|
38
39
|
// Locks operations with element by label value
|
39
40
|
mutable std::vector<std::mutex> label_op_locks_;
|
@@ -93,8 +94,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
93
94
|
size_t ef_construction = 200,
|
94
95
|
size_t random_seed = 100,
|
95
96
|
bool allow_replace_deleted = false)
|
96
|
-
:
|
97
|
-
|
97
|
+
: label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
|
98
|
+
link_list_locks_(max_elements),
|
98
99
|
element_levels_(max_elements),
|
99
100
|
allow_replace_deleted_(allow_replace_deleted) {
|
100
101
|
max_elements_ = max_elements;
|
@@ -102,7 +103,13 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
102
103
|
data_size_ = s->get_data_size();
|
103
104
|
fstdistfunc_ = s->get_dist_func();
|
104
105
|
dist_func_param_ = s->get_dist_func_param();
|
105
|
-
|
106
|
+
if ( M <= 10000 ) {
|
107
|
+
M_ = M;
|
108
|
+
} else {
|
109
|
+
HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
|
110
|
+
HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
|
111
|
+
M_ = 10000;
|
112
|
+
}
|
106
113
|
maxM_ = M_;
|
107
114
|
maxM0_ = M_ * 2;
|
108
115
|
ef_construction_ = std::max(ef_construction, M_);
|
@@ -123,7 +130,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
123
130
|
|
124
131
|
cur_element_count = 0;
|
125
132
|
|
126
|
-
visited_list_pool_ = new VisitedListPool(1, max_elements);
|
133
|
+
visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));
|
127
134
|
|
128
135
|
// initializations for special treatment of the first node
|
129
136
|
enterpoint_node_ = -1;
|
@@ -139,13 +146,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
139
146
|
|
140
147
|
|
141
148
|
~HierarchicalNSW() {
|
149
|
+
clear();
|
150
|
+
}
|
151
|
+
|
152
|
+
void clear() {
|
142
153
|
free(data_level0_memory_);
|
154
|
+
data_level0_memory_ = nullptr;
|
143
155
|
for (tableint i = 0; i < cur_element_count; i++) {
|
144
156
|
if (element_levels_[i] > 0)
|
145
157
|
free(linkLists_[i]);
|
146
158
|
}
|
147
159
|
free(linkLists_);
|
148
|
-
|
160
|
+
linkLists_ = nullptr;
|
161
|
+
cur_element_count = 0;
|
162
|
+
visited_list_pool_.reset(nullptr);
|
149
163
|
}
|
150
164
|
|
151
165
|
|
@@ -292,9 +306,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
292
306
|
}
|
293
307
|
|
294
308
|
|
295
|
-
|
309
|
+
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
|
310
|
+
template <bool bare_bone_search = true, bool collect_metrics = false>
|
296
311
|
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
|
297
|
-
searchBaseLayerST(
|
312
|
+
searchBaseLayerST(
|
313
|
+
tableint ep_id,
|
314
|
+
const void *data_point,
|
315
|
+
size_t ef,
|
316
|
+
BaseFilterFunctor* isIdAllowed = nullptr,
|
317
|
+
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
|
298
318
|
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
|
299
319
|
vl_type *visited_array = vl->mass;
|
300
320
|
vl_type visited_array_tag = vl->curV;
|
@@ -303,10 +323,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
303
323
|
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
|
304
324
|
|
305
325
|
dist_t lowerBound;
|
306
|
-
if (
|
307
|
-
|
326
|
+
if (bare_bone_search ||
|
327
|
+
(!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
|
328
|
+
char* ep_data = getDataByInternalId(ep_id);
|
329
|
+
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
|
308
330
|
lowerBound = dist;
|
309
331
|
top_candidates.emplace(dist, ep_id);
|
332
|
+
if (!bare_bone_search && stop_condition) {
|
333
|
+
stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
|
334
|
+
}
|
310
335
|
candidate_set.emplace(-dist, ep_id);
|
311
336
|
} else {
|
312
337
|
lowerBound = std::numeric_limits<dist_t>::max();
|
@@ -317,9 +342,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
317
342
|
|
318
343
|
while (!candidate_set.empty()) {
|
319
344
|
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
|
345
|
+
dist_t candidate_dist = -current_node_pair.first;
|
320
346
|
|
321
|
-
|
322
|
-
|
347
|
+
bool flag_stop_search;
|
348
|
+
if (bare_bone_search) {
|
349
|
+
flag_stop_search = candidate_dist > lowerBound;
|
350
|
+
} else {
|
351
|
+
if (stop_condition) {
|
352
|
+
flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
|
353
|
+
} else {
|
354
|
+
flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
|
355
|
+
}
|
356
|
+
}
|
357
|
+
if (flag_stop_search) {
|
323
358
|
break;
|
324
359
|
}
|
325
360
|
candidate_set.pop();
|
@@ -354,7 +389,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
354
389
|
char *currObj1 = (getDataByInternalId(candidate_id));
|
355
390
|
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
|
356
391
|
|
357
|
-
|
392
|
+
bool flag_consider_candidate;
|
393
|
+
if (!bare_bone_search && stop_condition) {
|
394
|
+
flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
|
395
|
+
} else {
|
396
|
+
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
|
397
|
+
}
|
398
|
+
|
399
|
+
if (flag_consider_candidate) {
|
358
400
|
candidate_set.emplace(-dist, candidate_id);
|
359
401
|
#ifdef USE_SSE
|
360
402
|
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
|
@@ -362,11 +404,30 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
362
404
|
_MM_HINT_T0); ////////////////////////
|
363
405
|
#endif
|
364
406
|
|
365
|
-
if (
|
407
|
+
if (bare_bone_search ||
|
408
|
+
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
|
366
409
|
top_candidates.emplace(dist, candidate_id);
|
410
|
+
if (!bare_bone_search && stop_condition) {
|
411
|
+
stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
|
412
|
+
}
|
413
|
+
}
|
367
414
|
|
368
|
-
|
415
|
+
bool flag_remove_extra = false;
|
416
|
+
if (!bare_bone_search && stop_condition) {
|
417
|
+
flag_remove_extra = stop_condition->should_remove_extra();
|
418
|
+
} else {
|
419
|
+
flag_remove_extra = top_candidates.size() > ef;
|
420
|
+
}
|
421
|
+
while (flag_remove_extra) {
|
422
|
+
tableint id = top_candidates.top().second;
|
369
423
|
top_candidates.pop();
|
424
|
+
if (!bare_bone_search && stop_condition) {
|
425
|
+
stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
|
426
|
+
flag_remove_extra = stop_condition->should_remove_extra();
|
427
|
+
} else {
|
428
|
+
flag_remove_extra = top_candidates.size() > ef;
|
429
|
+
}
|
430
|
+
}
|
370
431
|
|
371
432
|
if (!top_candidates.empty())
|
372
433
|
lowerBound = top_candidates.top().first;
|
@@ -381,8 +442,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
381
442
|
|
382
443
|
|
383
444
|
void getNeighborsByHeuristic2(
|
384
|
-
|
385
|
-
|
445
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
|
446
|
+
const size_t M) {
|
386
447
|
if (top_candidates.size() < M) {
|
387
448
|
return;
|
388
449
|
}
|
@@ -574,8 +635,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
574
635
|
if (new_max_elements < cur_element_count)
|
575
636
|
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
|
576
637
|
|
577
|
-
|
578
|
-
visited_list_pool_ = new VisitedListPool(1, new_max_elements);
|
638
|
+
visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));
|
579
639
|
|
580
640
|
element_levels_.resize(new_max_elements);
|
581
641
|
|
@@ -596,6 +656,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
596
656
|
max_elements_ = new_max_elements;
|
597
657
|
}
|
598
658
|
|
659
|
+
size_t indexFileSize() const {
|
660
|
+
size_t size = 0;
|
661
|
+
size += sizeof(offsetLevel0_);
|
662
|
+
size += sizeof(max_elements_);
|
663
|
+
size += sizeof(cur_element_count);
|
664
|
+
size += sizeof(size_data_per_element_);
|
665
|
+
size += sizeof(label_offset_);
|
666
|
+
size += sizeof(offsetData_);
|
667
|
+
size += sizeof(maxlevel_);
|
668
|
+
size += sizeof(enterpoint_node_);
|
669
|
+
size += sizeof(maxM_);
|
670
|
+
|
671
|
+
size += sizeof(maxM0_);
|
672
|
+
size += sizeof(M_);
|
673
|
+
size += sizeof(mult_);
|
674
|
+
size += sizeof(ef_construction_);
|
675
|
+
|
676
|
+
size += cur_element_count * size_data_per_element_;
|
677
|
+
|
678
|
+
for (size_t i = 0; i < cur_element_count; i++) {
|
679
|
+
unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
|
680
|
+
size += sizeof(linkListSize);
|
681
|
+
size += linkListSize;
|
682
|
+
}
|
683
|
+
return size;
|
684
|
+
}
|
599
685
|
|
600
686
|
void saveIndex(const std::string &location) {
|
601
687
|
std::ofstream output(location, std::ios::binary);
|
@@ -634,6 +720,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
634
720
|
if (!input.is_open())
|
635
721
|
throw std::runtime_error("Cannot open file");
|
636
722
|
|
723
|
+
clear();
|
637
724
|
// get file size:
|
638
725
|
input.seekg(0, input.end);
|
639
726
|
std::streampos total_filesize = input.tellg();
|
@@ -699,7 +786,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
699
786
|
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
|
700
787
|
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
|
701
788
|
|
702
|
-
visited_list_pool_
|
789
|
+
visited_list_pool_.reset(new VisitedListPool(1, max_elements));
|
703
790
|
|
704
791
|
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
|
705
792
|
if (linkLists_ == nullptr)
|
@@ -753,7 +840,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
753
840
|
size_t dim = *((size_t *) dist_func_param_);
|
754
841
|
std::vector<data_t> data;
|
755
842
|
data_t* data_ptr = (data_t*) data_ptrv;
|
756
|
-
for (
|
843
|
+
for (size_t i = 0; i < dim; i++) {
|
757
844
|
data.push_back(*data_ptr);
|
758
845
|
data_ptr += 1;
|
759
846
|
}
|
@@ -1217,11 +1304,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
1217
1304
|
}
|
1218
1305
|
|
1219
1306
|
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
1220
|
-
|
1221
|
-
|
1307
|
+
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
|
1308
|
+
if (bare_bone_search) {
|
1309
|
+
top_candidates = searchBaseLayerST<true>(
|
1222
1310
|
currObj, query_data, std::max(ef_, k), isIdAllowed);
|
1223
1311
|
} else {
|
1224
|
-
top_candidates = searchBaseLayerST<false
|
1312
|
+
top_candidates = searchBaseLayerST<false>(
|
1225
1313
|
currObj, query_data, std::max(ef_, k), isIdAllowed);
|
1226
1314
|
}
|
1227
1315
|
|
@@ -1237,6 +1325,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
1237
1325
|
}
|
1238
1326
|
|
1239
1327
|
|
1328
|
+
std::vector<std::pair<dist_t, labeltype >>
|
1329
|
+
searchStopConditionClosest(
|
1330
|
+
const void *query_data,
|
1331
|
+
BaseSearchStopCondition<dist_t>& stop_condition,
|
1332
|
+
BaseFilterFunctor* isIdAllowed = nullptr) const {
|
1333
|
+
std::vector<std::pair<dist_t, labeltype >> result;
|
1334
|
+
if (cur_element_count == 0) return result;
|
1335
|
+
|
1336
|
+
tableint currObj = enterpoint_node_;
|
1337
|
+
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
|
1338
|
+
|
1339
|
+
for (int level = maxlevel_; level > 0; level--) {
|
1340
|
+
bool changed = true;
|
1341
|
+
while (changed) {
|
1342
|
+
changed = false;
|
1343
|
+
unsigned int *data;
|
1344
|
+
|
1345
|
+
data = (unsigned int *) get_linklist(currObj, level);
|
1346
|
+
int size = getListCount(data);
|
1347
|
+
metric_hops++;
|
1348
|
+
metric_distance_computations+=size;
|
1349
|
+
|
1350
|
+
tableint *datal = (tableint *) (data + 1);
|
1351
|
+
for (int i = 0; i < size; i++) {
|
1352
|
+
tableint cand = datal[i];
|
1353
|
+
if (cand < 0 || cand > max_elements_)
|
1354
|
+
throw std::runtime_error("cand error");
|
1355
|
+
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
|
1356
|
+
|
1357
|
+
if (d < curdist) {
|
1358
|
+
curdist = d;
|
1359
|
+
currObj = cand;
|
1360
|
+
changed = true;
|
1361
|
+
}
|
1362
|
+
}
|
1363
|
+
}
|
1364
|
+
}
|
1365
|
+
|
1366
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
1367
|
+
top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);
|
1368
|
+
|
1369
|
+
size_t sz = top_candidates.size();
|
1370
|
+
result.resize(sz);
|
1371
|
+
while (!top_candidates.empty()) {
|
1372
|
+
result[--sz] = top_candidates.top();
|
1373
|
+
top_candidates.pop();
|
1374
|
+
}
|
1375
|
+
|
1376
|
+
stop_condition.filter_results(result);
|
1377
|
+
|
1378
|
+
return result;
|
1379
|
+
}
|
1380
|
+
|
1381
|
+
|
1240
1382
|
void checkIntegrity() {
|
1241
1383
|
int connections_checked = 0;
|
1242
1384
|
std::vector <int > inbound_connections_num(cur_element_count, 0);
|
@@ -1247,7 +1389,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
1247
1389
|
tableint *data = (tableint *) (ll_cur + 1);
|
1248
1390
|
std::unordered_set<tableint> s;
|
1249
1391
|
for (int j = 0; j < size; j++) {
|
1250
|
-
assert(data[j] > 0);
|
1251
1392
|
assert(data[j] < cur_element_count);
|
1252
1393
|
assert(data[j] != i);
|
1253
1394
|
inbound_connections_num[data[j]]++;
|
data/ext/hnswlib/src/hnswlib.h
CHANGED
@@ -1,4 +1,13 @@
|
|
1
1
|
#pragma once
|
2
|
+
|
3
|
+
// https://github.com/nmslib/hnswlib/pull/508
|
4
|
+
// This allows others to provide their own error stream (e.g. RcppHNSW)
|
5
|
+
#ifndef HNSWLIB_ERR_OVERRIDE
|
6
|
+
#define HNSWERR std::cerr
|
7
|
+
#else
|
8
|
+
#define HNSWERR HNSWLIB_ERR_OVERRIDE
|
9
|
+
#endif
|
10
|
+
|
2
11
|
#ifndef NO_MANUAL_VECTORIZATION
|
3
12
|
#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
|
4
13
|
#define USE_SSE
|
@@ -15,7 +24,7 @@
|
|
15
24
|
#ifdef _MSC_VER
|
16
25
|
#include <intrin.h>
|
17
26
|
#include <stdexcept>
|
18
|
-
void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
|
27
|
+
static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
|
19
28
|
__cpuidex(out, eax, ecx);
|
20
29
|
}
|
21
30
|
static __int64 xgetbv(unsigned int x) {
|
@@ -122,6 +131,24 @@ class BaseFilterFunctor {
|
|
122
131
|
virtual ~BaseFilterFunctor() {};
|
123
132
|
};
|
124
133
|
|
134
|
+
template<typename dist_t>
|
135
|
+
class BaseSearchStopCondition {
|
136
|
+
public:
|
137
|
+
virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;
|
138
|
+
|
139
|
+
virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;
|
140
|
+
|
141
|
+
virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;
|
142
|
+
|
143
|
+
virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;
|
144
|
+
|
145
|
+
virtual bool should_remove_extra() = 0;
|
146
|
+
|
147
|
+
virtual void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) = 0;
|
148
|
+
|
149
|
+
virtual ~BaseSearchStopCondition() {}
|
150
|
+
};
|
151
|
+
|
125
152
|
template <typename T>
|
126
153
|
class pairGreater {
|
127
154
|
public:
|
@@ -196,5 +223,6 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
|
|
196
223
|
|
197
224
|
#include "space_l2.h"
|
198
225
|
#include "space_ip.h"
|
226
|
+
#include "stop_condition.h"
|
199
227
|
#include "bruteforce.h"
|
200
228
|
#include "hnswalg.h"
|
data/ext/hnswlib/src/space_ip.h
CHANGED
@@ -157,19 +157,44 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void
|
|
157
157
|
|
158
158
|
__m512 sum512 = _mm512_set1_ps(0);
|
159
159
|
|
160
|
-
|
161
|
-
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
160
|
+
size_t loop = qty16 / 4;
|
162
161
|
|
162
|
+
while (loop--) {
|
163
163
|
__m512 v1 = _mm512_loadu_ps(pVect1);
|
164
|
-
pVect1 += 16;
|
165
164
|
__m512 v2 = _mm512_loadu_ps(pVect2);
|
165
|
+
pVect1 += 16;
|
166
|
+
pVect2 += 16;
|
167
|
+
|
168
|
+
__m512 v3 = _mm512_loadu_ps(pVect1);
|
169
|
+
__m512 v4 = _mm512_loadu_ps(pVect2);
|
170
|
+
pVect1 += 16;
|
171
|
+
pVect2 += 16;
|
172
|
+
|
173
|
+
__m512 v5 = _mm512_loadu_ps(pVect1);
|
174
|
+
__m512 v6 = _mm512_loadu_ps(pVect2);
|
175
|
+
pVect1 += 16;
|
166
176
|
pVect2 += 16;
|
167
|
-
|
177
|
+
|
178
|
+
__m512 v7 = _mm512_loadu_ps(pVect1);
|
179
|
+
__m512 v8 = _mm512_loadu_ps(pVect2);
|
180
|
+
pVect1 += 16;
|
181
|
+
pVect2 += 16;
|
182
|
+
|
183
|
+
sum512 = _mm512_fmadd_ps(v1, v2, sum512);
|
184
|
+
sum512 = _mm512_fmadd_ps(v3, v4, sum512);
|
185
|
+
sum512 = _mm512_fmadd_ps(v5, v6, sum512);
|
186
|
+
sum512 = _mm512_fmadd_ps(v7, v8, sum512);
|
168
187
|
}
|
169
188
|
|
170
|
-
|
171
|
-
|
189
|
+
while (pVect1 < pEnd1) {
|
190
|
+
__m512 v1 = _mm512_loadu_ps(pVect1);
|
191
|
+
__m512 v2 = _mm512_loadu_ps(pVect2);
|
192
|
+
pVect1 += 16;
|
193
|
+
pVect2 += 16;
|
194
|
+
sum512 = _mm512_fmadd_ps(v1, v2, sum512);
|
195
|
+
}
|
172
196
|
|
197
|
+
float sum = _mm512_reduce_add_ps(sum512);
|
173
198
|
return sum;
|
174
199
|
}
|
175
200
|
|
@@ -0,0 +1,276 @@
|
|
1
|
+
#pragma once
|
2
|
+
#include "space_l2.h"
|
3
|
+
#include "space_ip.h"
|
4
|
+
#include <assert.h>
|
5
|
+
#include <unordered_map>
|
6
|
+
|
7
|
+
namespace hnswlib {
|
8
|
+
|
9
|
+
template<typename DOCIDTYPE>
|
10
|
+
class BaseMultiVectorSpace : public SpaceInterface<float> {
|
11
|
+
public:
|
12
|
+
virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;
|
13
|
+
|
14
|
+
virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
|
15
|
+
};
|
16
|
+
|
17
|
+
|
18
|
+
template<typename DOCIDTYPE>
|
19
|
+
class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
|
20
|
+
DISTFUNC<float> fstdistfunc_;
|
21
|
+
size_t data_size_;
|
22
|
+
size_t vector_size_;
|
23
|
+
size_t dim_;
|
24
|
+
|
25
|
+
public:
|
26
|
+
MultiVectorL2Space(size_t dim) {
|
27
|
+
fstdistfunc_ = L2Sqr;
|
28
|
+
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
|
29
|
+
#if defined(USE_AVX512)
|
30
|
+
if (AVX512Capable())
|
31
|
+
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
|
32
|
+
else if (AVXCapable())
|
33
|
+
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
|
34
|
+
#elif defined(USE_AVX)
|
35
|
+
if (AVXCapable())
|
36
|
+
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
|
37
|
+
#endif
|
38
|
+
|
39
|
+
if (dim % 16 == 0)
|
40
|
+
fstdistfunc_ = L2SqrSIMD16Ext;
|
41
|
+
else if (dim % 4 == 0)
|
42
|
+
fstdistfunc_ = L2SqrSIMD4Ext;
|
43
|
+
else if (dim > 16)
|
44
|
+
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
|
45
|
+
else if (dim > 4)
|
46
|
+
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
|
47
|
+
#endif
|
48
|
+
dim_ = dim;
|
49
|
+
vector_size_ = dim * sizeof(float);
|
50
|
+
data_size_ = vector_size_ + sizeof(DOCIDTYPE);
|
51
|
+
}
|
52
|
+
|
53
|
+
size_t get_data_size() override {
|
54
|
+
return data_size_;
|
55
|
+
}
|
56
|
+
|
57
|
+
DISTFUNC<float> get_dist_func() override {
|
58
|
+
return fstdistfunc_;
|
59
|
+
}
|
60
|
+
|
61
|
+
void *get_dist_func_param() override {
|
62
|
+
return &dim_;
|
63
|
+
}
|
64
|
+
|
65
|
+
DOCIDTYPE get_doc_id(const void *datapoint) override {
|
66
|
+
return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
|
67
|
+
}
|
68
|
+
|
69
|
+
void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
|
70
|
+
*(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
|
71
|
+
}
|
72
|
+
|
73
|
+
~MultiVectorL2Space() {}
|
74
|
+
};
|
75
|
+
|
76
|
+
|
77
|
+
template<typename DOCIDTYPE>
|
78
|
+
class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
|
79
|
+
DISTFUNC<float> fstdistfunc_;
|
80
|
+
size_t data_size_;
|
81
|
+
size_t vector_size_;
|
82
|
+
size_t dim_;
|
83
|
+
|
84
|
+
public:
|
85
|
+
MultiVectorInnerProductSpace(size_t dim) {
|
86
|
+
fstdistfunc_ = InnerProductDistance;
|
87
|
+
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
|
88
|
+
#if defined(USE_AVX512)
|
89
|
+
if (AVX512Capable()) {
|
90
|
+
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
|
91
|
+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
|
92
|
+
} else if (AVXCapable()) {
|
93
|
+
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
|
94
|
+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
|
95
|
+
}
|
96
|
+
#elif defined(USE_AVX)
|
97
|
+
if (AVXCapable()) {
|
98
|
+
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
|
99
|
+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
|
100
|
+
}
|
101
|
+
#endif
|
102
|
+
#if defined(USE_AVX)
|
103
|
+
if (AVXCapable()) {
|
104
|
+
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
|
105
|
+
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
|
106
|
+
}
|
107
|
+
#endif
|
108
|
+
|
109
|
+
if (dim % 16 == 0)
|
110
|
+
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
|
111
|
+
else if (dim % 4 == 0)
|
112
|
+
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
|
113
|
+
else if (dim > 16)
|
114
|
+
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
|
115
|
+
else if (dim > 4)
|
116
|
+
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
|
117
|
+
#endif
|
118
|
+
vector_size_ = dim * sizeof(float);
|
119
|
+
data_size_ = vector_size_ + sizeof(DOCIDTYPE);
|
120
|
+
}
|
121
|
+
|
122
|
+
size_t get_data_size() override {
|
123
|
+
return data_size_;
|
124
|
+
}
|
125
|
+
|
126
|
+
DISTFUNC<float> get_dist_func() override {
|
127
|
+
return fstdistfunc_;
|
128
|
+
}
|
129
|
+
|
130
|
+
void *get_dist_func_param() override {
|
131
|
+
return &dim_;
|
132
|
+
}
|
133
|
+
|
134
|
+
DOCIDTYPE get_doc_id(const void *datapoint) override {
|
135
|
+
return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
|
136
|
+
}
|
137
|
+
|
138
|
+
void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
|
139
|
+
*(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
|
140
|
+
}
|
141
|
+
|
142
|
+
~MultiVectorInnerProductSpace() {}
|
143
|
+
};
|
144
|
+
|
145
|
+
|
146
|
+
template<typename DOCIDTYPE, typename dist_t>
|
147
|
+
class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
|
148
|
+
size_t curr_num_docs_;
|
149
|
+
size_t num_docs_to_search_;
|
150
|
+
size_t ef_collection_;
|
151
|
+
std::unordered_map<DOCIDTYPE, size_t> doc_counter_;
|
152
|
+
std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;
|
153
|
+
BaseMultiVectorSpace<DOCIDTYPE>& space_;
|
154
|
+
|
155
|
+
public:
|
156
|
+
MultiVectorSearchStopCondition(
|
157
|
+
BaseMultiVectorSpace<DOCIDTYPE>& space,
|
158
|
+
size_t num_docs_to_search,
|
159
|
+
size_t ef_collection = 10)
|
160
|
+
: space_(space) {
|
161
|
+
curr_num_docs_ = 0;
|
162
|
+
num_docs_to_search_ = num_docs_to_search;
|
163
|
+
ef_collection_ = std::max(ef_collection, num_docs_to_search);
|
164
|
+
}
|
165
|
+
|
166
|
+
void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
|
167
|
+
DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
|
168
|
+
if (doc_counter_[doc_id] == 0) {
|
169
|
+
curr_num_docs_ += 1;
|
170
|
+
}
|
171
|
+
search_results_.emplace(dist, doc_id);
|
172
|
+
doc_counter_[doc_id] += 1;
|
173
|
+
}
|
174
|
+
|
175
|
+
void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
|
176
|
+
DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
|
177
|
+
doc_counter_[doc_id] -= 1;
|
178
|
+
if (doc_counter_[doc_id] == 0) {
|
179
|
+
curr_num_docs_ -= 1;
|
180
|
+
}
|
181
|
+
search_results_.pop();
|
182
|
+
}
|
183
|
+
|
184
|
+
bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
|
185
|
+
bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
|
186
|
+
return stop_search;
|
187
|
+
}
|
188
|
+
|
189
|
+
bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
|
190
|
+
bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
|
191
|
+
return flag_consider_candidate;
|
192
|
+
}
|
193
|
+
|
194
|
+
bool should_remove_extra() override {
|
195
|
+
bool flag_remove_extra = curr_num_docs_ > ef_collection_;
|
196
|
+
return flag_remove_extra;
|
197
|
+
}
|
198
|
+
|
199
|
+
void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
|
200
|
+
while (curr_num_docs_ > num_docs_to_search_) {
|
201
|
+
dist_t dist_cand = candidates.back().first;
|
202
|
+
dist_t dist_res = search_results_.top().first;
|
203
|
+
assert(dist_cand == dist_res);
|
204
|
+
DOCIDTYPE doc_id = search_results_.top().second;
|
205
|
+
doc_counter_[doc_id] -= 1;
|
206
|
+
if (doc_counter_[doc_id] == 0) {
|
207
|
+
curr_num_docs_ -= 1;
|
208
|
+
}
|
209
|
+
search_results_.pop();
|
210
|
+
candidates.pop_back();
|
211
|
+
}
|
212
|
+
}
|
213
|
+
|
214
|
+
~MultiVectorSearchStopCondition() {}
|
215
|
+
};
|
216
|
+
|
217
|
+
|
218
|
+
template<typename dist_t>
|
219
|
+
class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
|
220
|
+
float epsilon_;
|
221
|
+
size_t min_num_candidates_;
|
222
|
+
size_t max_num_candidates_;
|
223
|
+
size_t curr_num_items_;
|
224
|
+
|
225
|
+
public:
|
226
|
+
EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
|
227
|
+
assert(min_num_candidates <= max_num_candidates);
|
228
|
+
epsilon_ = epsilon;
|
229
|
+
min_num_candidates_ = min_num_candidates;
|
230
|
+
max_num_candidates_ = max_num_candidates;
|
231
|
+
curr_num_items_ = 0;
|
232
|
+
}
|
233
|
+
|
234
|
+
void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
|
235
|
+
curr_num_items_ += 1;
|
236
|
+
}
|
237
|
+
|
238
|
+
void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
|
239
|
+
curr_num_items_ -= 1;
|
240
|
+
}
|
241
|
+
|
242
|
+
bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
|
243
|
+
if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
|
244
|
+
// new candidate can't improve found results
|
245
|
+
return true;
|
246
|
+
}
|
247
|
+
if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
|
248
|
+
// new candidate is out of epsilon region and
|
249
|
+
// minimum number of candidates is checked
|
250
|
+
return true;
|
251
|
+
}
|
252
|
+
return false;
|
253
|
+
}
|
254
|
+
|
255
|
+
bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
|
256
|
+
bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
|
257
|
+
return flag_consider_candidate;
|
258
|
+
}
|
259
|
+
|
260
|
+
bool should_remove_extra() {
|
261
|
+
bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
|
262
|
+
return flag_remove_extra;
|
263
|
+
}
|
264
|
+
|
265
|
+
void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
|
266
|
+
while (!candidates.empty() && candidates.back().first > epsilon_) {
|
267
|
+
candidates.pop_back();
|
268
|
+
}
|
269
|
+
while (candidates.size() > max_num_candidates_) {
|
270
|
+
candidates.pop_back();
|
271
|
+
}
|
272
|
+
}
|
273
|
+
|
274
|
+
~EpsilonSearchStopCondition() {}
|
275
|
+
};
|
276
|
+
} // namespace hnswlib
|
data/lib/hnswlib/version.rb
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
# Hnswlib.rb provides Ruby bindings for the Hnswlib.
|
4
4
|
module Hnswlib
|
5
5
|
# The version of Hnswlib.rb you install.
|
6
|
-
VERSION = '0.
|
6
|
+
VERSION = '0.9.0'
|
7
7
|
|
8
8
|
# The version of Hnswlib included with gem.
|
9
|
-
HSWLIB_VERSION = '0.
|
9
|
+
HSWLIB_VERSION = '0.8.0'
|
10
10
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: hnswlib
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-12-16 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
|
14
14
|
email:
|
@@ -30,6 +30,7 @@ files:
|
|
30
30
|
- ext/hnswlib/src/hnswlib.h
|
31
31
|
- ext/hnswlib/src/space_ip.h
|
32
32
|
- ext/hnswlib/src/space_l2.h
|
33
|
+
- ext/hnswlib/src/stop_condition.h
|
33
34
|
- ext/hnswlib/src/visited_list_pool.h
|
34
35
|
- lib/hnswlib.rb
|
35
36
|
- lib/hnswlib/version.rb
|
@@ -41,6 +42,7 @@ metadata:
|
|
41
42
|
homepage_uri: https://github.com/yoshoku/hnswlib.rb
|
42
43
|
source_code_uri: https://github.com/yoshoku/hnswlib.rb
|
43
44
|
changelog_uri: https://github.com/yoshoku/hnswlib.rb/blob/main/CHANGELOG.md
|
45
|
+
documentation_uri: https://yoshoku.github.io/hnswlib.rb/doc/
|
44
46
|
rubygems_mfa_required: 'true'
|
45
47
|
post_install_message:
|
46
48
|
rdoc_options: []
|
@@ -57,7 +59,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
57
59
|
- !ruby/object:Gem::Version
|
58
60
|
version: '0'
|
59
61
|
requirements: []
|
60
|
-
rubygems_version: 3.
|
62
|
+
rubygems_version: 3.4.22
|
61
63
|
signing_key:
|
62
64
|
specification_version: 4
|
63
65
|
summary: Ruby bindings for the Hnswlib.
|