hnswlib 0.8.1 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/hnswlib/hnswlibext.hpp +0 -4
- data/ext/hnswlib/src/bruteforce.h +10 -4
- data/ext/hnswlib/src/hnswalg.h +166 -25
- data/ext/hnswlib/src/hnswlib.h +29 -1
- data/ext/hnswlib/src/space_ip.h +31 -6
- data/ext/hnswlib/src/stop_condition.h +276 -0
- data/lib/hnswlib/version.rb +2 -2
- metadata +5 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 39458b40736b4a330a0c1769a3adcf37d9d40614b1270d33885932244e943b74
|
4
|
+
data.tar.gz: d49bf8e158c55235fdeca0656cd3ddab6a07e5abc1468eb8917b21f9baf5d89c
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ec8560f2b7aa9bb5df098b9c863e4602465a42e8ea27a19be55b9d62e9ba0d32fa70ccbcb13323b771370df5ded9370f9f91d17c8d63aecc18d2b45e29df0149
|
7
|
+
data.tar.gz: c9b434fb313f41e2f0570070944eb6eb3b4432a65fce2c63ca6592ab1de30f9b04834d87e1bd40db465dceea0eaac2279378eda537f6b7d51c2bdb0aa06e1ab6
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,8 @@
|
|
1
|
+
## [[0.9.0](https://github.com/yoshoku/hnswlib.rb/compare/v0.8.1...v0.9.0)] - 2023-12-16
|
2
|
+
|
3
|
+
- Update bundled hnswlib version to 0.8.0.
|
4
|
+
- Multi-vector document search and epsilon search, which are added only to the C++ version, are not supported. These features will be supported in future release.
|
5
|
+
|
1
6
|
## [0.8.1] - 2023-03-18
|
2
7
|
|
3
8
|
- Update the type declarations of HierarchicalNSW and BruteforceSearch along with recent changes.
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -528,10 +528,6 @@ private:
|
|
528
528
|
free(index->linkLists_);
|
529
529
|
index->linkLists_ = nullptr;
|
530
530
|
}
|
531
|
-
if (index->visited_list_pool_) {
|
532
|
-
delete index->visited_list_pool_;
|
533
|
-
index->visited_list_pool_ = nullptr;
|
534
|
-
}
|
535
531
|
|
536
532
|
try {
|
537
533
|
index->loadIndex(filename, space);
|
data/ext/hnswlib/src/bruteforce.h
CHANGED
@@ -85,10 +85,16 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
|
85
85
|
|
86
86
|
|
87
87
|
void removePoint(labeltype cur_external) {
|
88
|
-
|
88
|
+
std::unique_lock<std::mutex> lock(index_lock);
|
89
89
|
|
90
|
-
dict_external_to_internal.
|
90
|
+
auto found = dict_external_to_internal.find(cur_external);
|
91
|
+
if (found == dict_external_to_internal.end()) {
|
92
|
+
return;
|
93
|
+
}
|
94
|
+
|
95
|
+
dict_external_to_internal.erase(found);
|
91
96
|
|
97
|
+
size_t cur_c = found->second;
|
92
98
|
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
93
99
|
dict_external_to_internal[label] = cur_c;
|
94
100
|
memcpy(data_ + size_per_element_ * cur_c,
|
@@ -107,7 +113,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
|
107
113
|
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
108
114
|
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
|
109
115
|
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
110
|
-
topResults.
|
116
|
+
topResults.emplace(dist, label);
|
111
117
|
}
|
112
118
|
}
|
113
119
|
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
|
@@ -116,7 +122,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
|
116
122
|
if (dist <= lastdist) {
|
117
123
|
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
|
118
124
|
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
119
|
-
topResults.
|
125
|
+
topResults.emplace(dist, label);
|
120
126
|
}
|
121
127
|
if (topResults.size() > k)
|
122
128
|
topResults.pop();
|
data/ext/hnswlib/src/hnswalg.h
CHANGED
@@ -8,6 +8,7 @@
|
|
8
8
|
#include <assert.h>
|
9
9
|
#include <unordered_set>
|
10
10
|
#include <list>
|
11
|
+
#include <memory>
|
11
12
|
|
12
13
|
namespace hnswlib {
|
13
14
|
typedef unsigned int tableint;
|
@@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
33
34
|
double mult_{0.0}, revSize_{0.0};
|
34
35
|
int maxlevel_{0};
|
35
36
|
|
36
|
-
VisitedListPool *visited_list_pool_{nullptr};
|
37
|
+
std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};
|
37
38
|
|
38
39
|
// Locks operations with element by label value
|
39
40
|
mutable std::vector<std::mutex> label_op_locks_;
|
@@ -93,8 +94,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
93
94
|
size_t ef_construction = 200,
|
94
95
|
size_t random_seed = 100,
|
95
96
|
bool allow_replace_deleted = false)
|
96
|
-
:
|
97
|
-
|
97
|
+
: label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
|
98
|
+
link_list_locks_(max_elements),
|
98
99
|
element_levels_(max_elements),
|
99
100
|
allow_replace_deleted_(allow_replace_deleted) {
|
100
101
|
max_elements_ = max_elements;
|
@@ -102,7 +103,13 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
102
103
|
data_size_ = s->get_data_size();
|
103
104
|
fstdistfunc_ = s->get_dist_func();
|
104
105
|
dist_func_param_ = s->get_dist_func_param();
|
105
|
-
|
106
|
+
if ( M <= 10000 ) {
|
107
|
+
M_ = M;
|
108
|
+
} else {
|
109
|
+
HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
|
110
|
+
HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
|
111
|
+
M_ = 10000;
|
112
|
+
}
|
106
113
|
maxM_ = M_;
|
107
114
|
maxM0_ = M_ * 2;
|
108
115
|
ef_construction_ = std::max(ef_construction, M_);
|
@@ -123,7 +130,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
123
130
|
|
124
131
|
cur_element_count = 0;
|
125
132
|
|
126
|
-
visited_list_pool_ = new VisitedListPool(1, max_elements);
|
133
|
+
visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));
|
127
134
|
|
128
135
|
// initializations for special treatment of the first node
|
129
136
|
enterpoint_node_ = -1;
|
@@ -139,13 +146,20 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
139
146
|
|
140
147
|
|
141
148
|
~HierarchicalNSW() {
|
149
|
+
clear();
|
150
|
+
}
|
151
|
+
|
152
|
+
void clear() {
|
142
153
|
free(data_level0_memory_);
|
154
|
+
data_level0_memory_ = nullptr;
|
143
155
|
for (tableint i = 0; i < cur_element_count; i++) {
|
144
156
|
if (element_levels_[i] > 0)
|
145
157
|
free(linkLists_[i]);
|
146
158
|
}
|
147
159
|
free(linkLists_);
|
148
|
-
|
160
|
+
linkLists_ = nullptr;
|
161
|
+
cur_element_count = 0;
|
162
|
+
visited_list_pool_.reset(nullptr);
|
149
163
|
}
|
150
164
|
|
151
165
|
|
@@ -292,9 +306,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
292
306
|
}
|
293
307
|
|
294
308
|
|
295
|
-
|
309
|
+
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
|
310
|
+
template <bool bare_bone_search = true, bool collect_metrics = false>
|
296
311
|
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
|
297
|
-
searchBaseLayerST(
|
312
|
+
searchBaseLayerST(
|
313
|
+
tableint ep_id,
|
314
|
+
const void *data_point,
|
315
|
+
size_t ef,
|
316
|
+
BaseFilterFunctor* isIdAllowed = nullptr,
|
317
|
+
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
|
298
318
|
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
|
299
319
|
vl_type *visited_array = vl->mass;
|
300
320
|
vl_type visited_array_tag = vl->curV;
|
@@ -303,10 +323,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
303
323
|
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
|
304
324
|
|
305
325
|
dist_t lowerBound;
|
306
|
-
if (
|
307
|
-
|
326
|
+
if (bare_bone_search ||
|
327
|
+
(!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
|
328
|
+
char* ep_data = getDataByInternalId(ep_id);
|
329
|
+
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
|
308
330
|
lowerBound = dist;
|
309
331
|
top_candidates.emplace(dist, ep_id);
|
332
|
+
if (!bare_bone_search && stop_condition) {
|
333
|
+
stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
|
334
|
+
}
|
310
335
|
candidate_set.emplace(-dist, ep_id);
|
311
336
|
} else {
|
312
337
|
lowerBound = std::numeric_limits<dist_t>::max();
|
@@ -317,9 +342,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
317
342
|
|
318
343
|
while (!candidate_set.empty()) {
|
319
344
|
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
|
345
|
+
dist_t candidate_dist = -current_node_pair.first;
|
320
346
|
|
321
|
-
|
322
|
-
|
347
|
+
bool flag_stop_search;
|
348
|
+
if (bare_bone_search) {
|
349
|
+
flag_stop_search = candidate_dist > lowerBound;
|
350
|
+
} else {
|
351
|
+
if (stop_condition) {
|
352
|
+
flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
|
353
|
+
} else {
|
354
|
+
flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
|
355
|
+
}
|
356
|
+
}
|
357
|
+
if (flag_stop_search) {
|
323
358
|
break;
|
324
359
|
}
|
325
360
|
candidate_set.pop();
|
@@ -354,7 +389,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
354
389
|
char *currObj1 = (getDataByInternalId(candidate_id));
|
355
390
|
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
|
356
391
|
|
357
|
-
|
392
|
+
bool flag_consider_candidate;
|
393
|
+
if (!bare_bone_search && stop_condition) {
|
394
|
+
flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
|
395
|
+
} else {
|
396
|
+
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
|
397
|
+
}
|
398
|
+
|
399
|
+
if (flag_consider_candidate) {
|
358
400
|
candidate_set.emplace(-dist, candidate_id);
|
359
401
|
#ifdef USE_SSE
|
360
402
|
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
|
@@ -362,11 +404,30 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
362
404
|
_MM_HINT_T0); ////////////////////////
|
363
405
|
#endif
|
364
406
|
|
365
|
-
if (
|
407
|
+
if (bare_bone_search ||
|
408
|
+
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
|
366
409
|
top_candidates.emplace(dist, candidate_id);
|
410
|
+
if (!bare_bone_search && stop_condition) {
|
411
|
+
stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
|
412
|
+
}
|
413
|
+
}
|
367
414
|
|
368
|
-
|
415
|
+
bool flag_remove_extra = false;
|
416
|
+
if (!bare_bone_search && stop_condition) {
|
417
|
+
flag_remove_extra = stop_condition->should_remove_extra();
|
418
|
+
} else {
|
419
|
+
flag_remove_extra = top_candidates.size() > ef;
|
420
|
+
}
|
421
|
+
while (flag_remove_extra) {
|
422
|
+
tableint id = top_candidates.top().second;
|
369
423
|
top_candidates.pop();
|
424
|
+
if (!bare_bone_search && stop_condition) {
|
425
|
+
stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
|
426
|
+
flag_remove_extra = stop_condition->should_remove_extra();
|
427
|
+
} else {
|
428
|
+
flag_remove_extra = top_candidates.size() > ef;
|
429
|
+
}
|
430
|
+
}
|
370
431
|
|
371
432
|
if (!top_candidates.empty())
|
372
433
|
lowerBound = top_candidates.top().first;
|
@@ -381,8 +442,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
381
442
|
|
382
443
|
|
383
444
|
void getNeighborsByHeuristic2(
|
384
|
-
|
385
|
-
|
445
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
|
446
|
+
const size_t M) {
|
386
447
|
if (top_candidates.size() < M) {
|
387
448
|
return;
|
388
449
|
}
|
@@ -574,8 +635,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
574
635
|
if (new_max_elements < cur_element_count)
|
575
636
|
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
|
576
637
|
|
577
|
-
|
578
|
-
visited_list_pool_ = new VisitedListPool(1, new_max_elements);
|
638
|
+
visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));
|
579
639
|
|
580
640
|
element_levels_.resize(new_max_elements);
|
581
641
|
|
@@ -596,6 +656,32 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
596
656
|
max_elements_ = new_max_elements;
|
597
657
|
}
|
598
658
|
|
659
|
+
size_t indexFileSize() const {
|
660
|
+
size_t size = 0;
|
661
|
+
size += sizeof(offsetLevel0_);
|
662
|
+
size += sizeof(max_elements_);
|
663
|
+
size += sizeof(cur_element_count);
|
664
|
+
size += sizeof(size_data_per_element_);
|
665
|
+
size += sizeof(label_offset_);
|
666
|
+
size += sizeof(offsetData_);
|
667
|
+
size += sizeof(maxlevel_);
|
668
|
+
size += sizeof(enterpoint_node_);
|
669
|
+
size += sizeof(maxM_);
|
670
|
+
|
671
|
+
size += sizeof(maxM0_);
|
672
|
+
size += sizeof(M_);
|
673
|
+
size += sizeof(mult_);
|
674
|
+
size += sizeof(ef_construction_);
|
675
|
+
|
676
|
+
size += cur_element_count * size_data_per_element_;
|
677
|
+
|
678
|
+
for (size_t i = 0; i < cur_element_count; i++) {
|
679
|
+
unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
|
680
|
+
size += sizeof(linkListSize);
|
681
|
+
size += linkListSize;
|
682
|
+
}
|
683
|
+
return size;
|
684
|
+
}
|
599
685
|
|
600
686
|
void saveIndex(const std::string &location) {
|
601
687
|
std::ofstream output(location, std::ios::binary);
|
@@ -634,6 +720,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
634
720
|
if (!input.is_open())
|
635
721
|
throw std::runtime_error("Cannot open file");
|
636
722
|
|
723
|
+
clear();
|
637
724
|
// get file size:
|
638
725
|
input.seekg(0, input.end);
|
639
726
|
std::streampos total_filesize = input.tellg();
|
@@ -699,7 +786,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
699
786
|
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
|
700
787
|
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
|
701
788
|
|
702
|
-
visited_list_pool_ = new VisitedListPool(1, max_elements);
|
789
|
+
visited_list_pool_.reset(new VisitedListPool(1, max_elements));
|
703
790
|
|
704
791
|
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
|
705
792
|
if (linkLists_ == nullptr)
|
@@ -753,7 +840,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
753
840
|
size_t dim = *((size_t *) dist_func_param_);
|
754
841
|
std::vector<data_t> data;
|
755
842
|
data_t* data_ptr = (data_t*) data_ptrv;
|
756
|
-
for (
|
843
|
+
for (size_t i = 0; i < dim; i++) {
|
757
844
|
data.push_back(*data_ptr);
|
758
845
|
data_ptr += 1;
|
759
846
|
}
|
@@ -1217,11 +1304,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
1217
1304
|
}
|
1218
1305
|
|
1219
1306
|
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
1220
|
-
|
1221
|
-
|
1307
|
+
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
|
1308
|
+
if (bare_bone_search) {
|
1309
|
+
top_candidates = searchBaseLayerST<true>(
|
1222
1310
|
currObj, query_data, std::max(ef_, k), isIdAllowed);
|
1223
1311
|
} else {
|
1224
|
-
top_candidates = searchBaseLayerST<false
|
1312
|
+
top_candidates = searchBaseLayerST<false>(
|
1225
1313
|
currObj, query_data, std::max(ef_, k), isIdAllowed);
|
1226
1314
|
}
|
1227
1315
|
|
@@ -1237,6 +1325,60 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
1237
1325
|
}
|
1238
1326
|
|
1239
1327
|
|
1328
|
+
std::vector<std::pair<dist_t, labeltype >>
|
1329
|
+
searchStopConditionClosest(
|
1330
|
+
const void *query_data,
|
1331
|
+
BaseSearchStopCondition<dist_t>& stop_condition,
|
1332
|
+
BaseFilterFunctor* isIdAllowed = nullptr) const {
|
1333
|
+
std::vector<std::pair<dist_t, labeltype >> result;
|
1334
|
+
if (cur_element_count == 0) return result;
|
1335
|
+
|
1336
|
+
tableint currObj = enterpoint_node_;
|
1337
|
+
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
|
1338
|
+
|
1339
|
+
for (int level = maxlevel_; level > 0; level--) {
|
1340
|
+
bool changed = true;
|
1341
|
+
while (changed) {
|
1342
|
+
changed = false;
|
1343
|
+
unsigned int *data;
|
1344
|
+
|
1345
|
+
data = (unsigned int *) get_linklist(currObj, level);
|
1346
|
+
int size = getListCount(data);
|
1347
|
+
metric_hops++;
|
1348
|
+
metric_distance_computations+=size;
|
1349
|
+
|
1350
|
+
tableint *datal = (tableint *) (data + 1);
|
1351
|
+
for (int i = 0; i < size; i++) {
|
1352
|
+
tableint cand = datal[i];
|
1353
|
+
if (cand < 0 || cand > max_elements_)
|
1354
|
+
throw std::runtime_error("cand error");
|
1355
|
+
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
|
1356
|
+
|
1357
|
+
if (d < curdist) {
|
1358
|
+
curdist = d;
|
1359
|
+
currObj = cand;
|
1360
|
+
changed = true;
|
1361
|
+
}
|
1362
|
+
}
|
1363
|
+
}
|
1364
|
+
}
|
1365
|
+
|
1366
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
1367
|
+
top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);
|
1368
|
+
|
1369
|
+
size_t sz = top_candidates.size();
|
1370
|
+
result.resize(sz);
|
1371
|
+
while (!top_candidates.empty()) {
|
1372
|
+
result[--sz] = top_candidates.top();
|
1373
|
+
top_candidates.pop();
|
1374
|
+
}
|
1375
|
+
|
1376
|
+
stop_condition.filter_results(result);
|
1377
|
+
|
1378
|
+
return result;
|
1379
|
+
}
|
1380
|
+
|
1381
|
+
|
1240
1382
|
void checkIntegrity() {
|
1241
1383
|
int connections_checked = 0;
|
1242
1384
|
std::vector <int > inbound_connections_num(cur_element_count, 0);
|
@@ -1247,7 +1389,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
|
1247
1389
|
tableint *data = (tableint *) (ll_cur + 1);
|
1248
1390
|
std::unordered_set<tableint> s;
|
1249
1391
|
for (int j = 0; j < size; j++) {
|
1250
|
-
assert(data[j] > 0);
|
1251
1392
|
assert(data[j] < cur_element_count);
|
1252
1393
|
assert(data[j] != i);
|
1253
1394
|
inbound_connections_num[data[j]]++;
|
data/ext/hnswlib/src/hnswlib.h
CHANGED
@@ -1,4 +1,13 @@
|
|
1
1
|
#pragma once
|
2
|
+
|
3
|
+
// https://github.com/nmslib/hnswlib/pull/508
|
4
|
+
// This allows others to provide their own error stream (e.g. RcppHNSW)
|
5
|
+
#ifndef HNSWLIB_ERR_OVERRIDE
|
6
|
+
#define HNSWERR std::cerr
|
7
|
+
#else
|
8
|
+
#define HNSWERR HNSWLIB_ERR_OVERRIDE
|
9
|
+
#endif
|
10
|
+
|
2
11
|
#ifndef NO_MANUAL_VECTORIZATION
|
3
12
|
#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
|
4
13
|
#define USE_SSE
|
@@ -15,7 +24,7 @@
|
|
15
24
|
#ifdef _MSC_VER
|
16
25
|
#include <intrin.h>
|
17
26
|
#include <stdexcept>
|
18
|
-
void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
|
27
|
+
static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
|
19
28
|
__cpuidex(out, eax, ecx);
|
20
29
|
}
|
21
30
|
static __int64 xgetbv(unsigned int x) {
|
@@ -122,6 +131,24 @@ class BaseFilterFunctor {
|
|
122
131
|
virtual ~BaseFilterFunctor() {};
|
123
132
|
};
|
124
133
|
|
134
|
+
template<typename dist_t>
|
135
|
+
class BaseSearchStopCondition {
|
136
|
+
public:
|
137
|
+
virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;
|
138
|
+
|
139
|
+
virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;
|
140
|
+
|
141
|
+
virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;
|
142
|
+
|
143
|
+
virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;
|
144
|
+
|
145
|
+
virtual bool should_remove_extra() = 0;
|
146
|
+
|
147
|
+
virtual void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) = 0;
|
148
|
+
|
149
|
+
virtual ~BaseSearchStopCondition() {}
|
150
|
+
};
|
151
|
+
|
125
152
|
template <typename T>
|
126
153
|
class pairGreater {
|
127
154
|
public:
|
@@ -196,5 +223,6 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
|
|
196
223
|
|
197
224
|
#include "space_l2.h"
|
198
225
|
#include "space_ip.h"
|
226
|
+
#include "stop_condition.h"
|
199
227
|
#include "bruteforce.h"
|
200
228
|
#include "hnswalg.h"
|
data/ext/hnswlib/src/space_ip.h
CHANGED
@@ -157,19 +157,44 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void
|
|
157
157
|
|
158
158
|
__m512 sum512 = _mm512_set1_ps(0);
|
159
159
|
|
160
|
-
|
161
|
-
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
160
|
+
size_t loop = qty16 / 4;
|
162
161
|
|
162
|
+
while (loop--) {
|
163
163
|
__m512 v1 = _mm512_loadu_ps(pVect1);
|
164
|
-
pVect1 += 16;
|
165
164
|
__m512 v2 = _mm512_loadu_ps(pVect2);
|
165
|
+
pVect1 += 16;
|
166
|
+
pVect2 += 16;
|
167
|
+
|
168
|
+
__m512 v3 = _mm512_loadu_ps(pVect1);
|
169
|
+
__m512 v4 = _mm512_loadu_ps(pVect2);
|
170
|
+
pVect1 += 16;
|
171
|
+
pVect2 += 16;
|
172
|
+
|
173
|
+
__m512 v5 = _mm512_loadu_ps(pVect1);
|
174
|
+
__m512 v6 = _mm512_loadu_ps(pVect2);
|
175
|
+
pVect1 += 16;
|
166
176
|
pVect2 += 16;
|
167
|
-
|
177
|
+
|
178
|
+
__m512 v7 = _mm512_loadu_ps(pVect1);
|
179
|
+
__m512 v8 = _mm512_loadu_ps(pVect2);
|
180
|
+
pVect1 += 16;
|
181
|
+
pVect2 += 16;
|
182
|
+
|
183
|
+
sum512 = _mm512_fmadd_ps(v1, v2, sum512);
|
184
|
+
sum512 = _mm512_fmadd_ps(v3, v4, sum512);
|
185
|
+
sum512 = _mm512_fmadd_ps(v5, v6, sum512);
|
186
|
+
sum512 = _mm512_fmadd_ps(v7, v8, sum512);
|
168
187
|
}
|
169
188
|
|
170
|
-
|
171
|
-
|
189
|
+
while (pVect1 < pEnd1) {
|
190
|
+
__m512 v1 = _mm512_loadu_ps(pVect1);
|
191
|
+
__m512 v2 = _mm512_loadu_ps(pVect2);
|
192
|
+
pVect1 += 16;
|
193
|
+
pVect2 += 16;
|
194
|
+
sum512 = _mm512_fmadd_ps(v1, v2, sum512);
|
195
|
+
}
|
172
196
|
|
197
|
+
float sum = _mm512_reduce_add_ps(sum512);
|
173
198
|
return sum;
|
174
199
|
}
|
175
200
|
|
data/ext/hnswlib/src/stop_condition.h
ADDED
@@ -0,0 +1,276 @@
|
|
1
|
+
#pragma once
|
2
|
+
#include "space_l2.h"
|
3
|
+
#include "space_ip.h"
|
4
|
+
#include <assert.h>
|
5
|
+
#include <unordered_map>
|
6
|
+
|
7
|
+
namespace hnswlib {
|
8
|
+
|
9
|
+
template<typename DOCIDTYPE>
|
10
|
+
class BaseMultiVectorSpace : public SpaceInterface<float> {
|
11
|
+
public:
|
12
|
+
virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;
|
13
|
+
|
14
|
+
virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
|
15
|
+
};
|
16
|
+
|
17
|
+
|
18
|
+
template<typename DOCIDTYPE>
|
19
|
+
class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
|
20
|
+
DISTFUNC<float> fstdistfunc_;
|
21
|
+
size_t data_size_;
|
22
|
+
size_t vector_size_;
|
23
|
+
size_t dim_;
|
24
|
+
|
25
|
+
public:
|
26
|
+
MultiVectorL2Space(size_t dim) {
|
27
|
+
fstdistfunc_ = L2Sqr;
|
28
|
+
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
|
29
|
+
#if defined(USE_AVX512)
|
30
|
+
if (AVX512Capable())
|
31
|
+
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
|
32
|
+
else if (AVXCapable())
|
33
|
+
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
|
34
|
+
#elif defined(USE_AVX)
|
35
|
+
if (AVXCapable())
|
36
|
+
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
|
37
|
+
#endif
|
38
|
+
|
39
|
+
if (dim % 16 == 0)
|
40
|
+
fstdistfunc_ = L2SqrSIMD16Ext;
|
41
|
+
else if (dim % 4 == 0)
|
42
|
+
fstdistfunc_ = L2SqrSIMD4Ext;
|
43
|
+
else if (dim > 16)
|
44
|
+
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
|
45
|
+
else if (dim > 4)
|
46
|
+
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
|
47
|
+
#endif
|
48
|
+
dim_ = dim;
|
49
|
+
vector_size_ = dim * sizeof(float);
|
50
|
+
data_size_ = vector_size_ + sizeof(DOCIDTYPE);
|
51
|
+
}
|
52
|
+
|
53
|
+
size_t get_data_size() override {
|
54
|
+
return data_size_;
|
55
|
+
}
|
56
|
+
|
57
|
+
DISTFUNC<float> get_dist_func() override {
|
58
|
+
return fstdistfunc_;
|
59
|
+
}
|
60
|
+
|
61
|
+
void *get_dist_func_param() override {
|
62
|
+
return &dim_;
|
63
|
+
}
|
64
|
+
|
65
|
+
DOCIDTYPE get_doc_id(const void *datapoint) override {
|
66
|
+
return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
|
67
|
+
}
|
68
|
+
|
69
|
+
void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
|
70
|
+
*(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
|
71
|
+
}
|
72
|
+
|
73
|
+
~MultiVectorL2Space() {}
|
74
|
+
};
|
75
|
+
|
76
|
+
|
77
|
+
template<typename DOCIDTYPE>
|
78
|
+
class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
|
79
|
+
DISTFUNC<float> fstdistfunc_;
|
80
|
+
size_t data_size_;
|
81
|
+
size_t vector_size_;
|
82
|
+
size_t dim_;
|
83
|
+
|
84
|
+
public:
|
85
|
+
MultiVectorInnerProductSpace(size_t dim) {
|
86
|
+
fstdistfunc_ = InnerProductDistance;
|
87
|
+
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
|
88
|
+
#if defined(USE_AVX512)
|
89
|
+
if (AVX512Capable()) {
|
90
|
+
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
|
91
|
+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
|
92
|
+
} else if (AVXCapable()) {
|
93
|
+
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
|
94
|
+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
|
95
|
+
}
|
96
|
+
#elif defined(USE_AVX)
|
97
|
+
if (AVXCapable()) {
|
98
|
+
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
|
99
|
+
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
|
100
|
+
}
|
101
|
+
#endif
|
102
|
+
#if defined(USE_AVX)
|
103
|
+
if (AVXCapable()) {
|
104
|
+
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
|
105
|
+
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
|
106
|
+
}
|
107
|
+
#endif
|
108
|
+
|
109
|
+
if (dim % 16 == 0)
|
110
|
+
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
|
111
|
+
else if (dim % 4 == 0)
|
112
|
+
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
|
113
|
+
else if (dim > 16)
|
114
|
+
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
|
115
|
+
else if (dim > 4)
|
116
|
+
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
|
117
|
+
#endif
|
118
|
+
vector_size_ = dim * sizeof(float);
|
119
|
+
data_size_ = vector_size_ + sizeof(DOCIDTYPE);
|
120
|
+
}
|
121
|
+
|
122
|
+
size_t get_data_size() override {
|
123
|
+
return data_size_;
|
124
|
+
}
|
125
|
+
|
126
|
+
DISTFUNC<float> get_dist_func() override {
|
127
|
+
return fstdistfunc_;
|
128
|
+
}
|
129
|
+
|
130
|
+
void *get_dist_func_param() override {
|
131
|
+
return &dim_;
|
132
|
+
}
|
133
|
+
|
134
|
+
DOCIDTYPE get_doc_id(const void *datapoint) override {
|
135
|
+
return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
|
136
|
+
}
|
137
|
+
|
138
|
+
void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
|
139
|
+
*(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
|
140
|
+
}
|
141
|
+
|
142
|
+
~MultiVectorInnerProductSpace() {}
|
143
|
+
};
|
144
|
+
|
145
|
+
|
146
|
+
template<typename DOCIDTYPE, typename dist_t>
|
147
|
+
class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
|
148
|
+
size_t curr_num_docs_;
|
149
|
+
size_t num_docs_to_search_;
|
150
|
+
size_t ef_collection_;
|
151
|
+
std::unordered_map<DOCIDTYPE, size_t> doc_counter_;
|
152
|
+
std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;
|
153
|
+
BaseMultiVectorSpace<DOCIDTYPE>& space_;
|
154
|
+
|
155
|
+
public:
|
156
|
+
MultiVectorSearchStopCondition(
|
157
|
+
BaseMultiVectorSpace<DOCIDTYPE>& space,
|
158
|
+
size_t num_docs_to_search,
|
159
|
+
size_t ef_collection = 10)
|
160
|
+
: space_(space) {
|
161
|
+
curr_num_docs_ = 0;
|
162
|
+
num_docs_to_search_ = num_docs_to_search;
|
163
|
+
ef_collection_ = std::max(ef_collection, num_docs_to_search);
|
164
|
+
}
|
165
|
+
|
166
|
+
void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
|
167
|
+
DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
|
168
|
+
if (doc_counter_[doc_id] == 0) {
|
169
|
+
curr_num_docs_ += 1;
|
170
|
+
}
|
171
|
+
search_results_.emplace(dist, doc_id);
|
172
|
+
doc_counter_[doc_id] += 1;
|
173
|
+
}
|
174
|
+
|
175
|
+
void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
|
176
|
+
DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
|
177
|
+
doc_counter_[doc_id] -= 1;
|
178
|
+
if (doc_counter_[doc_id] == 0) {
|
179
|
+
curr_num_docs_ -= 1;
|
180
|
+
}
|
181
|
+
search_results_.pop();
|
182
|
+
}
|
183
|
+
|
184
|
+
bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
|
185
|
+
bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
|
186
|
+
return stop_search;
|
187
|
+
}
|
188
|
+
|
189
|
+
bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
|
190
|
+
bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
|
191
|
+
return flag_consider_candidate;
|
192
|
+
}
|
193
|
+
|
194
|
+
bool should_remove_extra() override {
|
195
|
+
bool flag_remove_extra = curr_num_docs_ > ef_collection_;
|
196
|
+
return flag_remove_extra;
|
197
|
+
}
|
198
|
+
|
199
|
+
void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
|
200
|
+
while (curr_num_docs_ > num_docs_to_search_) {
|
201
|
+
dist_t dist_cand = candidates.back().first;
|
202
|
+
dist_t dist_res = search_results_.top().first;
|
203
|
+
assert(dist_cand == dist_res);
|
204
|
+
DOCIDTYPE doc_id = search_results_.top().second;
|
205
|
+
doc_counter_[doc_id] -= 1;
|
206
|
+
if (doc_counter_[doc_id] == 0) {
|
207
|
+
curr_num_docs_ -= 1;
|
208
|
+
}
|
209
|
+
search_results_.pop();
|
210
|
+
candidates.pop_back();
|
211
|
+
}
|
212
|
+
}
|
213
|
+
|
214
|
+
~MultiVectorSearchStopCondition() {}
|
215
|
+
};
|
216
|
+
|
217
|
+
|
218
|
+
template<typename dist_t>
|
219
|
+
class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
|
220
|
+
float epsilon_;
|
221
|
+
size_t min_num_candidates_;
|
222
|
+
size_t max_num_candidates_;
|
223
|
+
size_t curr_num_items_;
|
224
|
+
|
225
|
+
public:
|
226
|
+
EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
|
227
|
+
assert(min_num_candidates <= max_num_candidates);
|
228
|
+
epsilon_ = epsilon;
|
229
|
+
min_num_candidates_ = min_num_candidates;
|
230
|
+
max_num_candidates_ = max_num_candidates;
|
231
|
+
curr_num_items_ = 0;
|
232
|
+
}
|
233
|
+
|
234
|
+
void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
|
235
|
+
curr_num_items_ += 1;
|
236
|
+
}
|
237
|
+
|
238
|
+
void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
|
239
|
+
curr_num_items_ -= 1;
|
240
|
+
}
|
241
|
+
|
242
|
+
bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
|
243
|
+
if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
|
244
|
+
// new candidate can't improve found results
|
245
|
+
return true;
|
246
|
+
}
|
247
|
+
if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
|
248
|
+
// new candidate is out of epsilon region and
|
249
|
+
// minimum number of candidates is checked
|
250
|
+
return true;
|
251
|
+
}
|
252
|
+
return false;
|
253
|
+
}
|
254
|
+
|
255
|
+
bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
|
256
|
+
bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
|
257
|
+
return flag_consider_candidate;
|
258
|
+
}
|
259
|
+
|
260
|
+
bool should_remove_extra() {
|
261
|
+
bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
|
262
|
+
return flag_remove_extra;
|
263
|
+
}
|
264
|
+
|
265
|
+
void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
|
266
|
+
while (!candidates.empty() && candidates.back().first > epsilon_) {
|
267
|
+
candidates.pop_back();
|
268
|
+
}
|
269
|
+
while (candidates.size() > max_num_candidates_) {
|
270
|
+
candidates.pop_back();
|
271
|
+
}
|
272
|
+
}
|
273
|
+
|
274
|
+
~EpsilonSearchStopCondition() {}
|
275
|
+
};
|
276
|
+
} // namespace hnswlib
|
data/lib/hnswlib/version.rb
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
# Hnswlib.rb provides Ruby bindings for the Hnswlib.
|
4
4
|
module Hnswlib
|
5
5
|
# The version of Hnswlib.rb you install.
|
6
|
-
VERSION = '0.8.1'
|
6
|
+
VERSION = '0.9.0'
|
7
7
|
|
8
8
|
# The version of Hnswlib included with gem.
|
9
|
-
HSWLIB_VERSION = '0.7.0'
|
9
|
+
HSWLIB_VERSION = '0.8.0'
|
10
10
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: hnswlib
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.8.1
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-03-18 00:00:00.000000000 Z
|
11
|
+
date: 2023-12-16 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
|
14
14
|
email:
|
@@ -30,6 +30,7 @@ files:
|
|
30
30
|
- ext/hnswlib/src/hnswlib.h
|
31
31
|
- ext/hnswlib/src/space_ip.h
|
32
32
|
- ext/hnswlib/src/space_l2.h
|
33
|
+
- ext/hnswlib/src/stop_condition.h
|
33
34
|
- ext/hnswlib/src/visited_list_pool.h
|
34
35
|
- lib/hnswlib.rb
|
35
36
|
- lib/hnswlib/version.rb
|
@@ -41,6 +42,7 @@ metadata:
|
|
41
42
|
homepage_uri: https://github.com/yoshoku/hnswlib.rb
|
42
43
|
source_code_uri: https://github.com/yoshoku/hnswlib.rb
|
43
44
|
changelog_uri: https://github.com/yoshoku/hnswlib.rb/blob/main/CHANGELOG.md
|
45
|
+
documentation_uri: https://yoshoku.github.io/hnswlib.rb/doc/
|
44
46
|
rubygems_mfa_required: 'true'
|
45
47
|
post_install_message:
|
46
48
|
rdoc_options: []
|
@@ -57,7 +59,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
57
59
|
- !ruby/object:Gem::Version
|
58
60
|
version: '0'
|
59
61
|
requirements: []
|
60
|
-
rubygems_version: 3.
|
62
|
+
rubygems_version: 3.4.22
|
61
63
|
signing_key:
|
62
64
|
specification_version: 4
|
63
65
|
summary: Ruby bindings for the Hnswlib.
|