hnswlib 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.github/workflows/build.yml +20 -0
- data/.gitignore +18 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +5 -0
- data/CODE_OF_CONDUCT.md +84 -0
- data/Gemfile +10 -0
- data/LICENSE.txt +176 -0
- data/README.md +56 -0
- data/Rakefile +17 -0
- data/ext/hnswlib/extconf.rb +11 -0
- data/ext/hnswlib/hnswlibext.cpp +29 -0
- data/ext/hnswlib/hnswlibext.hpp +420 -0
- data/ext/hnswlib/src/LICENSE +201 -0
- data/ext/hnswlib/src/bruteforce.h +152 -0
- data/ext/hnswlib/src/hnswalg.h +1192 -0
- data/ext/hnswlib/src/hnswlib.h +108 -0
- data/ext/hnswlib/src/space_ip.h +282 -0
- data/ext/hnswlib/src/space_l2.h +281 -0
- data/ext/hnswlib/src/visited_list_pool.h +78 -0
- data/hnswlib.gemspec +35 -0
- data/lib/hnswlib.rb +154 -0
- data/lib/hnswlib/version.rb +9 -0
- metadata +69 -0
@@ -0,0 +1,152 @@
|
|
1
|
+
#pragma once
|
2
|
+
#include <unordered_map>
|
3
|
+
#include <fstream>
|
4
|
+
#include <mutex>
|
5
|
+
#include <algorithm>
|
6
|
+
|
7
|
+
namespace hnswlib {
|
8
|
+
template<typename dist_t>
|
9
|
+
class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
10
|
+
public:
|
11
|
+
BruteforceSearch(SpaceInterface <dist_t> *s) {
|
12
|
+
|
13
|
+
}
|
14
|
+
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location) {
|
15
|
+
loadIndex(location, s);
|
16
|
+
}
|
17
|
+
|
18
|
+
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
|
19
|
+
maxelements_ = maxElements;
|
20
|
+
data_size_ = s->get_data_size();
|
21
|
+
fstdistfunc_ = s->get_dist_func();
|
22
|
+
dist_func_param_ = s->get_dist_func_param();
|
23
|
+
size_per_element_ = data_size_ + sizeof(labeltype);
|
24
|
+
data_ = (char *) malloc(maxElements * size_per_element_);
|
25
|
+
if (data_ == nullptr)
|
26
|
+
std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
|
27
|
+
cur_element_count = 0;
|
28
|
+
}
|
29
|
+
|
30
|
+
~BruteforceSearch() {
|
31
|
+
free(data_);
|
32
|
+
}
|
33
|
+
|
34
|
+
char *data_;
|
35
|
+
size_t maxelements_;
|
36
|
+
size_t cur_element_count;
|
37
|
+
size_t size_per_element_;
|
38
|
+
|
39
|
+
size_t data_size_;
|
40
|
+
DISTFUNC <dist_t> fstdistfunc_;
|
41
|
+
void *dist_func_param_;
|
42
|
+
std::mutex index_lock;
|
43
|
+
|
44
|
+
std::unordered_map<labeltype,size_t > dict_external_to_internal;
|
45
|
+
|
46
|
+
void addPoint(const void *datapoint, labeltype label) {
|
47
|
+
|
48
|
+
int idx;
|
49
|
+
{
|
50
|
+
std::unique_lock<std::mutex> lock(index_lock);
|
51
|
+
|
52
|
+
|
53
|
+
|
54
|
+
auto search=dict_external_to_internal.find(label);
|
55
|
+
if (search != dict_external_to_internal.end()) {
|
56
|
+
idx=search->second;
|
57
|
+
}
|
58
|
+
else{
|
59
|
+
if (cur_element_count >= maxelements_) {
|
60
|
+
throw std::runtime_error("The number of elements exceeds the specified limit\n");
|
61
|
+
}
|
62
|
+
idx=cur_element_count;
|
63
|
+
dict_external_to_internal[label] = idx;
|
64
|
+
cur_element_count++;
|
65
|
+
}
|
66
|
+
}
|
67
|
+
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
|
68
|
+
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
|
69
|
+
|
70
|
+
|
71
|
+
|
72
|
+
|
73
|
+
};
|
74
|
+
|
75
|
+
void removePoint(labeltype cur_external) {
|
76
|
+
size_t cur_c=dict_external_to_internal[cur_external];
|
77
|
+
|
78
|
+
dict_external_to_internal.erase(cur_external);
|
79
|
+
|
80
|
+
labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
81
|
+
dict_external_to_internal[label]=cur_c;
|
82
|
+
memcpy(data_ + size_per_element_ * cur_c,
|
83
|
+
data_ + size_per_element_ * (cur_element_count-1),
|
84
|
+
data_size_+sizeof(labeltype));
|
85
|
+
cur_element_count--;
|
86
|
+
|
87
|
+
}
|
88
|
+
|
89
|
+
|
90
|
+
std::priority_queue<std::pair<dist_t, labeltype >>
|
91
|
+
searchKnn(const void *query_data, size_t k) const {
|
92
|
+
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
|
93
|
+
if (cur_element_count == 0) return topResults;
|
94
|
+
for (int i = 0; i < k; i++) {
|
95
|
+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
96
|
+
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
|
97
|
+
data_size_))));
|
98
|
+
}
|
99
|
+
dist_t lastdist = topResults.top().first;
|
100
|
+
for (int i = k; i < cur_element_count; i++) {
|
101
|
+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
102
|
+
if (dist <= lastdist) {
|
103
|
+
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
|
104
|
+
data_size_))));
|
105
|
+
if (topResults.size() > k)
|
106
|
+
topResults.pop();
|
107
|
+
lastdist = topResults.top().first;
|
108
|
+
}
|
109
|
+
|
110
|
+
}
|
111
|
+
return topResults;
|
112
|
+
};
|
113
|
+
|
114
|
+
void saveIndex(const std::string &location) {
|
115
|
+
std::ofstream output(location, std::ios::binary);
|
116
|
+
std::streampos position;
|
117
|
+
|
118
|
+
writeBinaryPOD(output, maxelements_);
|
119
|
+
writeBinaryPOD(output, size_per_element_);
|
120
|
+
writeBinaryPOD(output, cur_element_count);
|
121
|
+
|
122
|
+
output.write(data_, maxelements_ * size_per_element_);
|
123
|
+
|
124
|
+
output.close();
|
125
|
+
}
|
126
|
+
|
127
|
+
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
|
128
|
+
|
129
|
+
|
130
|
+
std::ifstream input(location, std::ios::binary);
|
131
|
+
std::streampos position;
|
132
|
+
|
133
|
+
readBinaryPOD(input, maxelements_);
|
134
|
+
readBinaryPOD(input, size_per_element_);
|
135
|
+
readBinaryPOD(input, cur_element_count);
|
136
|
+
|
137
|
+
data_size_ = s->get_data_size();
|
138
|
+
fstdistfunc_ = s->get_dist_func();
|
139
|
+
dist_func_param_ = s->get_dist_func_param();
|
140
|
+
size_per_element_ = data_size_ + sizeof(labeltype);
|
141
|
+
data_ = (char *) malloc(maxelements_ * size_per_element_);
|
142
|
+
if (data_ == nullptr)
|
143
|
+
std::runtime_error("Not enough memory: loadIndex failed to allocate data");
|
144
|
+
|
145
|
+
input.read(data_, maxelements_ * size_per_element_);
|
146
|
+
|
147
|
+
input.close();
|
148
|
+
|
149
|
+
}
|
150
|
+
|
151
|
+
};
|
152
|
+
}
|
@@ -0,0 +1,1192 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include "visited_list_pool.h"
|
4
|
+
#include "hnswlib.h"
|
5
|
+
#include <atomic>
|
6
|
+
#include <random>
|
7
|
+
#include <stdlib.h>
|
8
|
+
#include <assert.h>
|
9
|
+
#include <unordered_set>
|
10
|
+
#include <list>
|
11
|
+
|
12
|
+
namespace hnswlib {
|
13
|
+
typedef unsigned int tableint;
|
14
|
+
typedef unsigned int linklistsizeint;
|
15
|
+
|
16
|
+
template<typename dist_t>
|
17
|
+
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
|
18
|
+
public:
|
19
|
+
static const tableint max_update_element_locks = 65536;
|
20
|
+
/// Creates an empty shell index; `s` is not consulted.
/// The owning pointers and the element counter are put in a defined state here because the
/// destructor frees/deletes them and loops up to `cur_element_count` unconditionally.
HierarchicalNSW(SpaceInterface<dist_t> *s) {
    data_level0_memory_ = nullptr;
    linkLists_ = nullptr;
    visited_list_pool_ = nullptr;
    cur_element_count = 0;
    max_elements_ = 0;
}
|
23
|
+
|
24
|
+
/// Builds the index from the file at `location` by delegating to loadIndex().
/// `max_elements` is forwarded to loadIndex(); `nmslib` is accepted but not used here.
HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
    this->loadIndex(location, s, max_elements);
}
|
27
|
+
|
28
|
+
/// Creates an empty index with room for `max_elements` points of the space `s`.
/// @param M                links kept per element on upper levels; level 0 allows 2 * M
/// @param ef_construction  candidate-list size used while building (raised to at least M)
/// @param random_seed      seed for the level generator; seed + 1 feeds the update generator
/// @throws std::runtime_error if either storage buffer cannot be allocated
HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
        link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
    max_elements_ = max_elements;

    has_deletions_ = false;
    data_size_ = s->get_data_size();
    fstdistfunc_ = s->get_dist_func();
    dist_func_param_ = s->get_dist_func_param();
    M_ = M;
    maxM_ = M_;
    maxM0_ = M_ * 2;
    ef_construction_ = std::max(ef_construction, M_);
    ef_ = 10;

    level_generator_.seed(random_seed);
    update_probability_generator_.seed(random_seed + 1);

    // Level-0 record layout: [link count + maxM0_ links][point bytes][external label].
    size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
    size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
    offsetData_ = size_links_level0_;
    label_offset_ = size_links_level0_ + data_size_;
    offsetLevel0_ = 0;

    cur_element_count = 0;

    // The counters are only ever incremented elsewhere, so start them from a defined value.
    metric_distance_computations = 0;
    metric_hops = 0;

    //initializations for special treatment of the first node
    enterpoint_node_ = -1;
    maxlevel_ = -1;

    size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
    mult_ = 1 / log(1.0 * M_);
    revSize_ = 1.0 / mult_;

    // A constructor that throws never runs the destructor, so every failure path below
    // releases what was acquired before it.
    data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
    if (data_level0_memory_ == nullptr)
        throw std::runtime_error("Not enough memory");

    linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
    if (linkLists_ == nullptr) {
        free(data_level0_memory_);
        throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
    }

    try {
        visited_list_pool_ = new VisitedListPool(1, max_elements);
    } catch (...) {
        free(linkLists_);
        free(data_level0_memory_);
        throw;
    }
}
|
72
|
+
|
73
|
+
struct CompareByFirst {
|
74
|
+
constexpr bool operator()(std::pair<dist_t, tableint> const &a,
|
75
|
+
std::pair<dist_t, tableint> const &b) const noexcept {
|
76
|
+
return a.first < b.first;
|
77
|
+
}
|
78
|
+
};
|
79
|
+
|
80
|
+
/// Releases the level-0 buffer, every per-element upper-level link block, the table of
/// those blocks, and the visited-list pool.
~HierarchicalNSW() {
    free(data_level0_memory_);

    // Only elements that reach above level 0 own a separately allocated link block.
    for (tableint id = 0; id < cur_element_count; ++id) {
        if (element_levels_[id] <= 0)
            continue;
        free(linkLists_[id]);
    }
    free(linkLists_);

    delete visited_list_pool_;
}
|
90
|
+
|
91
|
+
size_t max_elements_;
|
92
|
+
size_t cur_element_count;
|
93
|
+
size_t size_data_per_element_;
|
94
|
+
size_t size_links_per_element_;
|
95
|
+
|
96
|
+
size_t M_;
|
97
|
+
size_t maxM_;
|
98
|
+
size_t maxM0_;
|
99
|
+
size_t ef_construction_;
|
100
|
+
|
101
|
+
double mult_, revSize_;
|
102
|
+
int maxlevel_;
|
103
|
+
|
104
|
+
|
105
|
+
VisitedListPool *visited_list_pool_;
|
106
|
+
std::mutex cur_element_count_guard_;
|
107
|
+
|
108
|
+
std::vector<std::mutex> link_list_locks_;
|
109
|
+
|
110
|
+
// Locks to prevent race condition during update/insert of an element at same time.
|
111
|
+
// Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
|
112
|
+
std::vector<std::mutex> link_list_update_locks_;
|
113
|
+
tableint enterpoint_node_;
|
114
|
+
|
115
|
+
|
116
|
+
size_t size_links_level0_;
|
117
|
+
size_t offsetData_, offsetLevel0_;
|
118
|
+
|
119
|
+
|
120
|
+
char *data_level0_memory_;
|
121
|
+
char **linkLists_;
|
122
|
+
std::vector<int> element_levels_;
|
123
|
+
|
124
|
+
size_t data_size_;
|
125
|
+
|
126
|
+
bool has_deletions_;
|
127
|
+
|
128
|
+
|
129
|
+
size_t label_offset_;
|
130
|
+
DISTFUNC<dist_t> fstdistfunc_;
|
131
|
+
void *dist_func_param_;
|
132
|
+
std::unordered_map<labeltype, tableint> label_lookup_;
|
133
|
+
|
134
|
+
std::default_random_engine level_generator_;
|
135
|
+
std::default_random_engine update_probability_generator_;
|
136
|
+
|
137
|
+
/// Returns the external label stored in the level-0 record of `internal_id`.
inline labeltype getExternalLabel(tableint internal_id) const {
    const char *label_slot = data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_;
    labeltype label;
    memcpy(&label, label_slot, sizeof(labeltype));  // byte copy: the slot is read without assuming alignment
    return label;
}
|
142
|
+
|
143
|
+
/// Writes `label` into the label slot of the level-0 record of `internal_id`.
inline void setExternalLabel(tableint internal_id, labeltype label) const {
    char *label_slot = data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_;
    memcpy(label_slot, &label, sizeof(labeltype));
}
|
146
|
+
|
147
|
+
/// Returns a pointer to the label slot inside the level-0 record of `internal_id`.
inline labeltype *getExternalLabeLp(tableint internal_id) const {
    char *label_slot = data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_;
    return (labeltype *) label_slot;
}
|
150
|
+
|
151
|
+
/// Returns a pointer to the point bytes inside the level-0 record of `internal_id`.
inline char *getDataByInternalId(tableint internal_id) const {
    char *record = data_level0_memory_ + internal_id * size_data_per_element_;
    return record + offsetData_;
}
|
154
|
+
|
155
|
+
int getRandomLevel(double reverse_size) {
|
156
|
+
std::uniform_real_distribution<double> distribution(0.0, 1.0);
|
157
|
+
double r = -log(distribution(level_generator_)) * reverse_size;
|
158
|
+
return (int) r;
|
159
|
+
}
|
160
|
+
|
161
|
+
|
162
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
|
163
|
+
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
|
164
|
+
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
|
165
|
+
vl_type *visited_array = vl->mass;
|
166
|
+
vl_type visited_array_tag = vl->curV;
|
167
|
+
|
168
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
169
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
|
170
|
+
|
171
|
+
dist_t lowerBound;
|
172
|
+
if (!isMarkedDeleted(ep_id)) {
|
173
|
+
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
|
174
|
+
top_candidates.emplace(dist, ep_id);
|
175
|
+
lowerBound = dist;
|
176
|
+
candidateSet.emplace(-dist, ep_id);
|
177
|
+
} else {
|
178
|
+
lowerBound = std::numeric_limits<dist_t>::max();
|
179
|
+
candidateSet.emplace(-lowerBound, ep_id);
|
180
|
+
}
|
181
|
+
visited_array[ep_id] = visited_array_tag;
|
182
|
+
|
183
|
+
while (!candidateSet.empty()) {
|
184
|
+
std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
|
185
|
+
if ((-curr_el_pair.first) > lowerBound) {
|
186
|
+
break;
|
187
|
+
}
|
188
|
+
candidateSet.pop();
|
189
|
+
|
190
|
+
tableint curNodeNum = curr_el_pair.second;
|
191
|
+
|
192
|
+
std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
|
193
|
+
|
194
|
+
int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
|
195
|
+
if (layer == 0) {
|
196
|
+
data = (int*)get_linklist0(curNodeNum);
|
197
|
+
} else {
|
198
|
+
data = (int*)get_linklist(curNodeNum, layer);
|
199
|
+
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
|
200
|
+
}
|
201
|
+
size_t size = getListCount((linklistsizeint*)data);
|
202
|
+
tableint *datal = (tableint *) (data + 1);
|
203
|
+
#ifdef USE_SSE
|
204
|
+
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
|
205
|
+
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
|
206
|
+
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
|
207
|
+
_mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
|
208
|
+
#endif
|
209
|
+
|
210
|
+
for (size_t j = 0; j < size; j++) {
|
211
|
+
tableint candidate_id = *(datal + j);
|
212
|
+
// if (candidate_id == 0) continue;
|
213
|
+
#ifdef USE_SSE
|
214
|
+
_mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
|
215
|
+
_mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
|
216
|
+
#endif
|
217
|
+
if (visited_array[candidate_id] == visited_array_tag) continue;
|
218
|
+
visited_array[candidate_id] = visited_array_tag;
|
219
|
+
char *currObj1 = (getDataByInternalId(candidate_id));
|
220
|
+
|
221
|
+
dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
|
222
|
+
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
|
223
|
+
candidateSet.emplace(-dist1, candidate_id);
|
224
|
+
#ifdef USE_SSE
|
225
|
+
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
|
226
|
+
#endif
|
227
|
+
|
228
|
+
if (!isMarkedDeleted(candidate_id))
|
229
|
+
top_candidates.emplace(dist1, candidate_id);
|
230
|
+
|
231
|
+
if (top_candidates.size() > ef_construction_)
|
232
|
+
top_candidates.pop();
|
233
|
+
|
234
|
+
if (!top_candidates.empty())
|
235
|
+
lowerBound = top_candidates.top().first;
|
236
|
+
}
|
237
|
+
}
|
238
|
+
}
|
239
|
+
visited_list_pool_->releaseVisitedList(vl);
|
240
|
+
|
241
|
+
return top_candidates;
|
242
|
+
}
|
243
|
+
|
244
|
+
mutable std::atomic<long> metric_distance_computations;
|
245
|
+
mutable std::atomic<long> metric_hops;
|
246
|
+
|
247
|
+
template <bool has_deletions, bool collect_metrics=false>
|
248
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
|
249
|
+
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
|
250
|
+
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
|
251
|
+
vl_type *visited_array = vl->mass;
|
252
|
+
vl_type visited_array_tag = vl->curV;
|
253
|
+
|
254
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
255
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
|
256
|
+
|
257
|
+
dist_t lowerBound;
|
258
|
+
if (!has_deletions || !isMarkedDeleted(ep_id)) {
|
259
|
+
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
|
260
|
+
lowerBound = dist;
|
261
|
+
top_candidates.emplace(dist, ep_id);
|
262
|
+
candidate_set.emplace(-dist, ep_id);
|
263
|
+
} else {
|
264
|
+
lowerBound = std::numeric_limits<dist_t>::max();
|
265
|
+
candidate_set.emplace(-lowerBound, ep_id);
|
266
|
+
}
|
267
|
+
|
268
|
+
visited_array[ep_id] = visited_array_tag;
|
269
|
+
|
270
|
+
while (!candidate_set.empty()) {
|
271
|
+
|
272
|
+
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
|
273
|
+
|
274
|
+
if ((-current_node_pair.first) > lowerBound) {
|
275
|
+
break;
|
276
|
+
}
|
277
|
+
candidate_set.pop();
|
278
|
+
|
279
|
+
tableint current_node_id = current_node_pair.second;
|
280
|
+
int *data = (int *) get_linklist0(current_node_id);
|
281
|
+
size_t size = getListCount((linklistsizeint*)data);
|
282
|
+
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
|
283
|
+
if(collect_metrics){
|
284
|
+
metric_hops++;
|
285
|
+
metric_distance_computations+=size;
|
286
|
+
}
|
287
|
+
|
288
|
+
#ifdef USE_SSE
|
289
|
+
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
|
290
|
+
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
|
291
|
+
_mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
|
292
|
+
_mm_prefetch((char *) (data + 2), _MM_HINT_T0);
|
293
|
+
#endif
|
294
|
+
|
295
|
+
for (size_t j = 1; j <= size; j++) {
|
296
|
+
int candidate_id = *(data + j);
|
297
|
+
// if (candidate_id == 0) continue;
|
298
|
+
#ifdef USE_SSE
|
299
|
+
_mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
|
300
|
+
_mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
|
301
|
+
_MM_HINT_T0);////////////
|
302
|
+
#endif
|
303
|
+
if (!(visited_array[candidate_id] == visited_array_tag)) {
|
304
|
+
|
305
|
+
visited_array[candidate_id] = visited_array_tag;
|
306
|
+
|
307
|
+
char *currObj1 = (getDataByInternalId(candidate_id));
|
308
|
+
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
|
309
|
+
|
310
|
+
if (top_candidates.size() < ef || lowerBound > dist) {
|
311
|
+
candidate_set.emplace(-dist, candidate_id);
|
312
|
+
#ifdef USE_SSE
|
313
|
+
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
|
314
|
+
offsetLevel0_,///////////
|
315
|
+
_MM_HINT_T0);////////////////////////
|
316
|
+
#endif
|
317
|
+
|
318
|
+
if (!has_deletions || !isMarkedDeleted(candidate_id))
|
319
|
+
top_candidates.emplace(dist, candidate_id);
|
320
|
+
|
321
|
+
if (top_candidates.size() > ef)
|
322
|
+
top_candidates.pop();
|
323
|
+
|
324
|
+
if (!top_candidates.empty())
|
325
|
+
lowerBound = top_candidates.top().first;
|
326
|
+
}
|
327
|
+
}
|
328
|
+
}
|
329
|
+
}
|
330
|
+
|
331
|
+
visited_list_pool_->releaseVisitedList(vl);
|
332
|
+
return top_candidates;
|
333
|
+
}
|
334
|
+
|
335
|
+
void getNeighborsByHeuristic2(
|
336
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
|
337
|
+
const size_t M) {
|
338
|
+
if (top_candidates.size() < M) {
|
339
|
+
return;
|
340
|
+
}
|
341
|
+
|
342
|
+
std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
|
343
|
+
std::vector<std::pair<dist_t, tableint>> return_list;
|
344
|
+
while (top_candidates.size() > 0) {
|
345
|
+
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
|
346
|
+
top_candidates.pop();
|
347
|
+
}
|
348
|
+
|
349
|
+
while (queue_closest.size()) {
|
350
|
+
if (return_list.size() >= M)
|
351
|
+
break;
|
352
|
+
std::pair<dist_t, tableint> curent_pair = queue_closest.top();
|
353
|
+
dist_t dist_to_query = -curent_pair.first;
|
354
|
+
queue_closest.pop();
|
355
|
+
bool good = true;
|
356
|
+
|
357
|
+
for (std::pair<dist_t, tableint> second_pair : return_list) {
|
358
|
+
dist_t curdist =
|
359
|
+
fstdistfunc_(getDataByInternalId(second_pair.second),
|
360
|
+
getDataByInternalId(curent_pair.second),
|
361
|
+
dist_func_param_);;
|
362
|
+
if (curdist < dist_to_query) {
|
363
|
+
good = false;
|
364
|
+
break;
|
365
|
+
}
|
366
|
+
}
|
367
|
+
if (good) {
|
368
|
+
return_list.push_back(curent_pair);
|
369
|
+
}
|
370
|
+
}
|
371
|
+
|
372
|
+
for (std::pair<dist_t, tableint> curent_pair : return_list) {
|
373
|
+
top_candidates.emplace(-curent_pair.first, curent_pair.second);
|
374
|
+
}
|
375
|
+
}
|
376
|
+
|
377
|
+
|
378
|
+
/// Returns the level-0 link list (count word followed by ids) of `internal_id`.
linklistsizeint *get_linklist0(tableint internal_id) const {
    char *record = data_level0_memory_ + internal_id * size_data_per_element_;
    return (linklistsizeint *) (record + offsetLevel0_);
}
|
381
|
+
|
382
|
+
/// Same as get_linklist0(internal_id), but addresses the record inside the buffer passed in
/// (the parameter deliberately shadows the member of the same name).
linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
    char *record = data_level0_memory_ + internal_id * size_data_per_element_;
    return (linklistsizeint *) (record + offsetLevel0_);
}
|
385
|
+
|
386
|
+
/// Returns the link list of `internal_id` at upper level `level`.
/// Upper-level lists are stored back to back in `linkLists_[internal_id]`, level 1 first.
linklistsizeint *get_linklist(tableint internal_id, int level) const {
    char *upper_levels = linkLists_[internal_id];
    return (linklistsizeint *) (upper_levels + (level - 1) * size_links_per_element_);
}
|
389
|
+
|
390
|
+
/// Returns the link list of `internal_id` at `level`, choosing between the level-0 record
/// and the separately stored upper-level lists.
linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
    if (level == 0)
        return get_linklist0(internal_id);
    return get_linklist(internal_id, level);
}
|
393
|
+
|
394
|
+
/// Links element `cur_c` at `level` to the neighbours selected from `top_candidates`, and adds
/// the reverse link on each of them (re-running the selection heuristic on a neighbour whose
/// list is already full). `top_candidates` is consumed. `data_point` is not used here.
/// @param isUpdate  when true, `cur_c` may already have links and may already be listed by a neighbour
/// @return the last id taken from the pruned candidate heap
/// @throws std::runtime_error on any of the consistency checks below
tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
        int level, bool isUpdate) {
    size_t max_links = level ? maxM_ : maxM0_;

    getNeighborsByHeuristic2(top_candidates, M_);
    if (top_candidates.size() > M_)
        throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");

    std::vector<tableint> chosen;
    chosen.reserve(M_);
    for (; top_candidates.size() > 0; top_candidates.pop())
        chosen.push_back(top_candidates.top().second);

    tableint next_closest_entry_point = chosen.back();

    // Step 1: write the chosen ids into cur_c's own list.
    {
        linklistsizeint *own_list = (level == 0) ? get_linklist0(cur_c)
                                                 : get_linklist(cur_c, level);

        if (*own_list && !isUpdate) {
            throw std::runtime_error("The newly inserted element should have blank link list");
        }
        setListCount(own_list, chosen.size());

        tableint *own_links = (tableint *) (own_list + 1);
        for (size_t idx = 0; idx < chosen.size(); idx++) {
            if (own_links[idx] && !isUpdate)
                throw std::runtime_error("Possible memory corruption");
            if (level > element_levels_[chosen[idx]])
                throw std::runtime_error("Trying to make a link on a non-existent level");

            own_links[idx] = chosen[idx];
        }
    }

    // Step 2: add cur_c to each chosen neighbour's list, under that neighbour's lock.
    for (size_t idx = 0; idx < chosen.size(); idx++) {
        const tableint neighbour = chosen[idx];
        std::unique_lock <std::mutex> guard(link_list_locks_[neighbour]);

        linklistsizeint *other_list = (level == 0) ? get_linklist0(neighbour)
                                                   : get_linklist(neighbour, level);
        size_t other_count = getListCount(other_list);

        if (other_count > max_links)
            throw std::runtime_error("Bad value of sz_link_list_other");
        if (neighbour == cur_c)
            throw std::runtime_error("Trying to connect an element to itself");
        if (level > element_levels_[neighbour])
            throw std::runtime_error("Trying to make a link on a non-existent level");

        tableint *other_links = (tableint *) (other_list + 1);

        // On update the reverse link may already exist; then this neighbour is left untouched.
        bool already_linked = false;
        if (isUpdate) {
            for (size_t j = 0; j < other_count && !already_linked; j++)
                already_linked = (other_links[j] == cur_c);
        }
        if (already_linked)
            continue;

        if (other_count < max_links) {
            other_links[other_count] = cur_c;
            setListCount(other_list, other_count + 1);
            continue;
        }

        // The list is full: pool its current links with cur_c and let the heuristic pick.
        dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(neighbour),
                                    dist_func_param_);
        std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> pool;
        pool.emplace(d_max, cur_c);
        for (size_t j = 0; j < other_count; j++) {
            pool.emplace(
                    fstdistfunc_(getDataByInternalId(other_links[j]), getDataByInternalId(neighbour),
                                 dist_func_param_), other_links[j]);
        }

        getNeighborsByHeuristic2(pool, max_links);

        int written = 0;
        for (; pool.size() > 0; pool.pop()) {
            other_links[written] = pool.top().second;
            written++;
        }
        setListCount(other_list, written);
    }

    return next_closest_entry_point;
}
|
512
|
+
|
513
|
+
std::mutex global;
|
514
|
+
size_t ef_;
|
515
|
+
|
516
|
+
void setEf(size_t ef) {
|
517
|
+
ef_ = ef;
|
518
|
+
}
|
519
|
+
|
520
|
+
|
521
|
+
std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
|
522
|
+
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
|
523
|
+
if (cur_element_count == 0) return top_candidates;
|
524
|
+
tableint currObj = enterpoint_node_;
|
525
|
+
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
|
526
|
+
|
527
|
+
for (size_t level = maxlevel_; level > 0; level--) {
|
528
|
+
bool changed = true;
|
529
|
+
while (changed) {
|
530
|
+
changed = false;
|
531
|
+
int *data;
|
532
|
+
data = (int *) get_linklist(currObj,level);
|
533
|
+
int size = getListCount(data);
|
534
|
+
tableint *datal = (tableint *) (data + 1);
|
535
|
+
for (int i = 0; i < size; i++) {
|
536
|
+
tableint cand = datal[i];
|
537
|
+
if (cand < 0 || cand > max_elements_)
|
538
|
+
throw std::runtime_error("cand error");
|
539
|
+
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
|
540
|
+
|
541
|
+
if (d < curdist) {
|
542
|
+
curdist = d;
|
543
|
+
currObj = cand;
|
544
|
+
changed = true;
|
545
|
+
}
|
546
|
+
}
|
547
|
+
}
|
548
|
+
}
|
549
|
+
|
550
|
+
if (has_deletions_) {
|
551
|
+
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
|
552
|
+
ef_);
|
553
|
+
top_candidates.swap(top_candidates1);
|
554
|
+
}
|
555
|
+
else{
|
556
|
+
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
|
557
|
+
ef_);
|
558
|
+
top_candidates.swap(top_candidates1);
|
559
|
+
}
|
560
|
+
|
561
|
+
while (top_candidates.size() > k) {
|
562
|
+
top_candidates.pop();
|
563
|
+
}
|
564
|
+
return top_candidates;
|
565
|
+
};
|
566
|
+
|
567
|
+
void resizeIndex(size_t new_max_elements){
|
568
|
+
if (new_max_elements<cur_element_count)
|
569
|
+
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
|
570
|
+
|
571
|
+
|
572
|
+
delete visited_list_pool_;
|
573
|
+
visited_list_pool_ = new VisitedListPool(1, new_max_elements);
|
574
|
+
|
575
|
+
|
576
|
+
element_levels_.resize(new_max_elements);
|
577
|
+
|
578
|
+
std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
|
579
|
+
|
580
|
+
// Reallocate base layer
|
581
|
+
char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
|
582
|
+
if (data_level0_memory_new == nullptr)
|
583
|
+
throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
|
584
|
+
data_level0_memory_ = data_level0_memory_new;
|
585
|
+
|
586
|
+
// Reallocate all other layers
|
587
|
+
char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
|
588
|
+
if (linkLists_new == nullptr)
|
589
|
+
throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
|
590
|
+
linkLists_ = linkLists_new;
|
591
|
+
|
592
|
+
max_elements_ = new_max_elements;
|
593
|
+
}
|
594
|
+
|
595
|
+
void saveIndex(const std::string &location) {
|
596
|
+
std::ofstream output(location, std::ios::binary);
|
597
|
+
std::streampos position;
|
598
|
+
|
599
|
+
writeBinaryPOD(output, offsetLevel0_);
|
600
|
+
writeBinaryPOD(output, max_elements_);
|
601
|
+
writeBinaryPOD(output, cur_element_count);
|
602
|
+
writeBinaryPOD(output, size_data_per_element_);
|
603
|
+
writeBinaryPOD(output, label_offset_);
|
604
|
+
writeBinaryPOD(output, offsetData_);
|
605
|
+
writeBinaryPOD(output, maxlevel_);
|
606
|
+
writeBinaryPOD(output, enterpoint_node_);
|
607
|
+
writeBinaryPOD(output, maxM_);
|
608
|
+
|
609
|
+
writeBinaryPOD(output, maxM0_);
|
610
|
+
writeBinaryPOD(output, M_);
|
611
|
+
writeBinaryPOD(output, mult_);
|
612
|
+
writeBinaryPOD(output, ef_construction_);
|
613
|
+
|
614
|
+
output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
|
615
|
+
|
616
|
+
for (size_t i = 0; i < cur_element_count; i++) {
|
617
|
+
unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
|
618
|
+
writeBinaryPOD(output, linkListSize);
|
619
|
+
if (linkListSize)
|
620
|
+
output.write(linkLists_[i], linkListSize);
|
621
|
+
}
|
622
|
+
output.close();
|
623
|
+
}
|
624
|
+
|
625
|
+
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
|
626
|
+
|
627
|
+
|
628
|
+
std::ifstream input(location, std::ios::binary);
|
629
|
+
|
630
|
+
if (!input.is_open())
|
631
|
+
throw std::runtime_error("Cannot open file");
|
632
|
+
|
633
|
+
// get file size:
|
634
|
+
input.seekg(0,input.end);
|
635
|
+
std::streampos total_filesize=input.tellg();
|
636
|
+
input.seekg(0,input.beg);
|
637
|
+
|
638
|
+
readBinaryPOD(input, offsetLevel0_);
|
639
|
+
readBinaryPOD(input, max_elements_);
|
640
|
+
readBinaryPOD(input, cur_element_count);
|
641
|
+
|
642
|
+
size_t max_elements=max_elements_i;
|
643
|
+
if(max_elements < cur_element_count)
|
644
|
+
max_elements = max_elements_;
|
645
|
+
max_elements_ = max_elements;
|
646
|
+
readBinaryPOD(input, size_data_per_element_);
|
647
|
+
readBinaryPOD(input, label_offset_);
|
648
|
+
readBinaryPOD(input, offsetData_);
|
649
|
+
readBinaryPOD(input, maxlevel_);
|
650
|
+
readBinaryPOD(input, enterpoint_node_);
|
651
|
+
|
652
|
+
readBinaryPOD(input, maxM_);
|
653
|
+
readBinaryPOD(input, maxM0_);
|
654
|
+
readBinaryPOD(input, M_);
|
655
|
+
readBinaryPOD(input, mult_);
|
656
|
+
readBinaryPOD(input, ef_construction_);
|
657
|
+
|
658
|
+
|
659
|
+
data_size_ = s->get_data_size();
|
660
|
+
fstdistfunc_ = s->get_dist_func();
|
661
|
+
dist_func_param_ = s->get_dist_func_param();
|
662
|
+
|
663
|
+
auto pos=input.tellg();
|
664
|
+
|
665
|
+
|
666
|
+
/// Optional - check if index is ok:
|
667
|
+
|
668
|
+
input.seekg(cur_element_count * size_data_per_element_,input.cur);
|
669
|
+
for (size_t i = 0; i < cur_element_count; i++) {
|
670
|
+
if(input.tellg() < 0 || input.tellg()>=total_filesize){
|
671
|
+
throw std::runtime_error("Index seems to be corrupted or unsupported");
|
672
|
+
}
|
673
|
+
|
674
|
+
unsigned int linkListSize;
|
675
|
+
readBinaryPOD(input, linkListSize);
|
676
|
+
if (linkListSize != 0) {
|
677
|
+
input.seekg(linkListSize,input.cur);
|
678
|
+
}
|
679
|
+
}
|
680
|
+
|
681
|
+
// throw exception if it either corrupted or old index
|
682
|
+
if(input.tellg()!=total_filesize)
|
683
|
+
throw std::runtime_error("Index seems to be corrupted or unsupported");
|
684
|
+
|
685
|
+
input.clear();
|
686
|
+
|
687
|
+
/// Optional check end
|
688
|
+
|
689
|
+
input.seekg(pos,input.beg);
|
690
|
+
|
691
|
+
|
692
|
+
data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
|
693
|
+
if (data_level0_memory_ == nullptr)
|
694
|
+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
|
695
|
+
input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
|
696
|
+
|
697
|
+
|
698
|
+
|
699
|
+
|
700
|
+
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
|
701
|
+
|
702
|
+
|
703
|
+
size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
|
704
|
+
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
|
705
|
+
std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
|
706
|
+
|
707
|
+
|
708
|
+
visited_list_pool_ = new VisitedListPool(1, max_elements);
|
709
|
+
|
710
|
+
|
711
|
+
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
|
712
|
+
if (linkLists_ == nullptr)
|
713
|
+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
|
714
|
+
element_levels_ = std::vector<int>(max_elements);
|
715
|
+
revSize_ = 1.0 / mult_;
|
716
|
+
ef_ = 10;
|
717
|
+
for (size_t i = 0; i < cur_element_count; i++) {
|
718
|
+
label_lookup_[getExternalLabel(i)]=i;
|
719
|
+
unsigned int linkListSize;
|
720
|
+
readBinaryPOD(input, linkListSize);
|
721
|
+
if (linkListSize == 0) {
|
722
|
+
element_levels_[i] = 0;
|
723
|
+
|
724
|
+
linkLists_[i] = nullptr;
|
725
|
+
} else {
|
726
|
+
element_levels_[i] = linkListSize / size_links_per_element_;
|
727
|
+
linkLists_[i] = (char *) malloc(linkListSize);
|
728
|
+
if (linkLists_[i] == nullptr)
|
729
|
+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
|
730
|
+
input.read(linkLists_[i], linkListSize);
|
731
|
+
}
|
732
|
+
}
|
733
|
+
|
734
|
+
has_deletions_=false;
|
735
|
+
|
736
|
+
for (size_t i = 0; i < cur_element_count; i++) {
|
737
|
+
if(isMarkedDeleted(i))
|
738
|
+
has_deletions_=true;
|
739
|
+
}
|
740
|
+
|
741
|
+
input.close();
|
742
|
+
|
743
|
+
return;
|
744
|
+
}
|
745
|
+
|
746
|
+
template<typename data_t>
|
747
|
+
std::vector<data_t> getDataByLabel(labeltype label)
|
748
|
+
{
|
749
|
+
tableint label_c;
|
750
|
+
auto search = label_lookup_.find(label);
|
751
|
+
if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
|
752
|
+
throw std::runtime_error("Label not found");
|
753
|
+
}
|
754
|
+
label_c = search->second;
|
755
|
+
|
756
|
+
char* data_ptrv = getDataByInternalId(label_c);
|
757
|
+
size_t dim = *((size_t *) dist_func_param_);
|
758
|
+
std::vector<data_t> data;
|
759
|
+
data_t* data_ptr = (data_t*) data_ptrv;
|
760
|
+
for (int i = 0; i < dim; i++) {
|
761
|
+
data.push_back(*data_ptr);
|
762
|
+
data_ptr += 1;
|
763
|
+
}
|
764
|
+
return data;
|
765
|
+
}
|
766
|
+
|
767
|
+
// Bit that markDeletedInternal sets, unmarkDeletedInternal clears and isMarkedDeleted tests
// in the byte at offset 2 of an element's level-0 link-list memory.
static const unsigned char DELETE_MARK = 0x01;
|
768
|
+
// static const unsigned char REUSE_MARK = 0x10;
|
769
|
+
/**
|
770
|
+
* Marks an element with the given label deleted, does NOT really change the current graph.
|
771
|
+
* @param label
|
772
|
+
*/
|
773
|
+
void markDelete(labeltype label)
|
774
|
+
{
|
775
|
+
has_deletions_=true;
|
776
|
+
auto search = label_lookup_.find(label);
|
777
|
+
if (search == label_lookup_.end()) {
|
778
|
+
throw std::runtime_error("Label not found");
|
779
|
+
}
|
780
|
+
markDeletedInternal(search->second);
|
781
|
+
}
|
782
|
+
|
783
|
+
/**
|
784
|
+
* Uses the first 8 bits of the memory for the linked list to store the mark,
|
785
|
+
* whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
|
786
|
+
* @param internalId
|
787
|
+
*/
|
788
|
+
void markDeletedInternal(tableint internalId) {
|
789
|
+
unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
|
790
|
+
*ll_cur |= DELETE_MARK;
|
791
|
+
}
|
792
|
+
|
793
|
+
/**
|
794
|
+
* Remove the deleted mark of the node.
|
795
|
+
* @param internalId
|
796
|
+
*/
|
797
|
+
void unmarkDeletedInternal(tableint internalId) {
|
798
|
+
unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
|
799
|
+
*ll_cur &= ~DELETE_MARK;
|
800
|
+
}
|
801
|
+
|
802
|
+
/**
|
803
|
+
* Checks the first 8 bits of the memory to see if the element is marked deleted.
|
804
|
+
* @param internalId
|
805
|
+
* @return
|
806
|
+
*/
|
807
|
+
bool isMarkedDeleted(tableint internalId) const {
|
808
|
+
unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
|
809
|
+
return *ll_cur & DELETE_MARK;
|
810
|
+
}
|
811
|
+
|
812
|
+
// Reads the link count stored in the first two bytes that `ptr` points at.
unsigned short int getListCount(linklistsizeint * ptr) const {
    const unsigned short int *count = (const unsigned short int *) ptr;
    return *count;
}
|
815
|
+
|
816
|
+
// Stores `size` into the first two bytes that `ptr` points at; the remaining bytes are not touched.
void setListCount(linklistsizeint * ptr, unsigned short int size) const {
    unsigned short int *count = (unsigned short int *) ptr;
    *count = size;
}
|
819
|
+
|
820
|
+
// Adds (or updates) a point without forcing its level: the three-argument overload
// only honours a level greater than 0, so -1 makes it draw a random level.
void addPoint(const void *data_point, labeltype label) {
    const int unspecified_level = -1;
    addPoint(data_point, label, unspecified_level);
}
|
823
|
+
|
824
|
+
void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
|
825
|
+
// update the feature vector associated with existing point with new vector
|
826
|
+
memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
|
827
|
+
|
828
|
+
int maxLevelCopy = maxlevel_;
|
829
|
+
tableint entryPointCopy = enterpoint_node_;
|
830
|
+
// If point to be updated is entry point and graph just contains single element then just return.
|
831
|
+
if (entryPointCopy == internalId && cur_element_count == 1)
|
832
|
+
return;
|
833
|
+
|
834
|
+
int elemLevel = element_levels_[internalId];
|
835
|
+
std::uniform_real_distribution<float> distribution(0.0, 1.0);
|
836
|
+
for (int layer = 0; layer <= elemLevel; layer++) {
|
837
|
+
std::unordered_set<tableint> sCand;
|
838
|
+
std::unordered_set<tableint> sNeigh;
|
839
|
+
std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
|
840
|
+
if (listOneHop.size() == 0)
|
841
|
+
continue;
|
842
|
+
|
843
|
+
sCand.insert(internalId);
|
844
|
+
|
845
|
+
for (auto&& elOneHop : listOneHop) {
|
846
|
+
sCand.insert(elOneHop);
|
847
|
+
|
848
|
+
if (distribution(update_probability_generator_) > updateNeighborProbability)
|
849
|
+
continue;
|
850
|
+
|
851
|
+
sNeigh.insert(elOneHop);
|
852
|
+
|
853
|
+
std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
|
854
|
+
for (auto&& elTwoHop : listTwoHop) {
|
855
|
+
sCand.insert(elTwoHop);
|
856
|
+
}
|
857
|
+
}
|
858
|
+
|
859
|
+
for (auto&& neigh : sNeigh) {
|
860
|
+
// if (neigh == internalId)
|
861
|
+
// continue;
|
862
|
+
|
863
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
|
864
|
+
size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
|
865
|
+
size_t elementsToKeep = std::min(ef_construction_, size);
|
866
|
+
for (auto&& cand : sCand) {
|
867
|
+
if (cand == neigh)
|
868
|
+
continue;
|
869
|
+
|
870
|
+
dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
|
871
|
+
if (candidates.size() < elementsToKeep) {
|
872
|
+
candidates.emplace(distance, cand);
|
873
|
+
} else {
|
874
|
+
if (distance < candidates.top().first) {
|
875
|
+
candidates.pop();
|
876
|
+
candidates.emplace(distance, cand);
|
877
|
+
}
|
878
|
+
}
|
879
|
+
}
|
880
|
+
|
881
|
+
// Retrieve neighbours using heuristic and set connections.
|
882
|
+
getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
|
883
|
+
|
884
|
+
{
|
885
|
+
std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
|
886
|
+
linklistsizeint *ll_cur;
|
887
|
+
ll_cur = get_linklist_at_level(neigh, layer);
|
888
|
+
size_t candSize = candidates.size();
|
889
|
+
setListCount(ll_cur, candSize);
|
890
|
+
tableint *data = (tableint *) (ll_cur + 1);
|
891
|
+
for (size_t idx = 0; idx < candSize; idx++) {
|
892
|
+
data[idx] = candidates.top().second;
|
893
|
+
candidates.pop();
|
894
|
+
}
|
895
|
+
}
|
896
|
+
}
|
897
|
+
}
|
898
|
+
|
899
|
+
repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
|
900
|
+
};
|
901
|
+
|
902
|
+
void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
|
903
|
+
tableint currObj = entryPointInternalId;
|
904
|
+
if (dataPointLevel < maxLevel) {
|
905
|
+
dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
|
906
|
+
for (int level = maxLevel; level > dataPointLevel; level--) {
|
907
|
+
bool changed = true;
|
908
|
+
while (changed) {
|
909
|
+
changed = false;
|
910
|
+
unsigned int *data;
|
911
|
+
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
|
912
|
+
data = get_linklist_at_level(currObj,level);
|
913
|
+
int size = getListCount(data);
|
914
|
+
tableint *datal = (tableint *) (data + 1);
|
915
|
+
#ifdef USE_SSE
|
916
|
+
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
|
917
|
+
#endif
|
918
|
+
for (int i = 0; i < size; i++) {
|
919
|
+
#ifdef USE_SSE
|
920
|
+
_mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
|
921
|
+
#endif
|
922
|
+
tableint cand = datal[i];
|
923
|
+
dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
|
924
|
+
if (d < curdist) {
|
925
|
+
curdist = d;
|
926
|
+
currObj = cand;
|
927
|
+
changed = true;
|
928
|
+
}
|
929
|
+
}
|
930
|
+
}
|
931
|
+
}
|
932
|
+
}
|
933
|
+
|
934
|
+
if (dataPointLevel > maxLevel)
|
935
|
+
throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
|
936
|
+
|
937
|
+
for (int level = dataPointLevel; level >= 0; level--) {
|
938
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
|
939
|
+
currObj, dataPoint, level);
|
940
|
+
|
941
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
|
942
|
+
while (topCandidates.size() > 0) {
|
943
|
+
if (topCandidates.top().second != dataPointInternalId)
|
944
|
+
filteredTopCandidates.push(topCandidates.top());
|
945
|
+
|
946
|
+
topCandidates.pop();
|
947
|
+
}
|
948
|
+
|
949
|
+
// Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
|
950
|
+
// To prevent self loops, the `topCandidates` is filtered and thus can be empty.
|
951
|
+
if (filteredTopCandidates.size() > 0) {
|
952
|
+
bool epDeleted = isMarkedDeleted(entryPointInternalId);
|
953
|
+
if (epDeleted) {
|
954
|
+
filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
|
955
|
+
if (filteredTopCandidates.size() > ef_construction_)
|
956
|
+
filteredTopCandidates.pop();
|
957
|
+
}
|
958
|
+
|
959
|
+
currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
|
960
|
+
}
|
961
|
+
}
|
962
|
+
}
|
963
|
+
|
964
|
+
// Returns a copy of the element's link list at `level`, taken while holding that element's link-list lock.
std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
    std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
    unsigned int *header = get_linklist_at_level(internalId, level);
    const int count = getListCount(header);
    // The link ids start right after the header word.
    const tableint *first = (tableint *) (header + 1);
    return std::vector<tableint>(first, first + count);
}
|
973
|
+
|
974
|
+
/**
 * Inserts a new element, or updates the existing element when `label` is already known.
 *
 * @param data_point vector to store (data_size_ bytes are copied)
 * @param label      external label of the element
 * @param level      when greater than 0, the level to assign to the new element;
 *                   otherwise the level is drawn with getRandomLevel(mult_)
 * @return internal id of the inserted or updated element
 * @throws std::runtime_error if the index is full, a link list cannot be allocated,
 *                            or a stored neighbour id is out of range
 */
tableint addPoint(const void *data_point, labeltype label, int level) {

    tableint cur_c = 0;
    {
        // Checking if the element with the same label already exists
        // if so, updating it *instead* of creating a new element.
        std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
        auto search = label_lookup_.find(label);
        if (search != label_lookup_.end()) {
            tableint existingInternalId = search->second;
            templock_curr.unlock();

            std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);

            // Re-adding a deleted label revives the element.
            if (isMarkedDeleted(existingInternalId)) {
                unmarkDeletedInternal(existingInternalId);
            }
            updatePoint(data_point, existingInternalId, 1.0);

            return existingInternalId;
        }

        if (cur_element_count >= max_elements_) {
            throw std::runtime_error("The number of elements exceeds the specified limit");
        };

        cur_c = cur_element_count;
        cur_element_count++;
        label_lookup_[label] = cur_c;
    }

    // Take update lock to prevent race conditions on an element with insertion/update at the same time.
    std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
    std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
    int curlevel = getRandomLevel(mult_);
    if (level > 0)
        curlevel = level;

    element_levels_[cur_c] = curlevel;

    // The global lock stays held only when this element raises the maximum level.
    std::unique_lock <std::mutex> templock(global);
    int maxlevelcopy = maxlevel_;
    if (curlevel <= maxlevelcopy)
        templock.unlock();
    tableint currObj = enterpoint_node_;
    tableint enterpoint_copy = enterpoint_node_;

    memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);

    // Initialisation of the data and label
    memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
    memcpy(getDataByInternalId(cur_c), data_point, data_size_);

    if (curlevel) {
        linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
        if (linkLists_[cur_c] == nullptr)
            throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
        memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
    }

    // An entry point of -1 means the index has no elements yet.
    if ((signed)currObj != -1) {

        if (curlevel < maxlevelcopy) {

            // Greedy descent through the layers above the new element's level.
            dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
            for (int lvl = maxlevelcopy; lvl > curlevel; lvl--) {

                bool changed = true;
                while (changed) {
                    changed = false;
                    unsigned int *data;
                    std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
                    data = get_linklist(currObj, lvl);
                    int size = getListCount(data);

                    tableint *datal = (tableint *) (data + 1);
                    for (int i = 0; i < size; i++) {
                        tableint cand = datal[i];
                        // Valid ids are [0, max_elements_); the id is treated as unsigned
                        // (cf. the (signed) cast above), so only the upper bound is checked.
                        if (cand >= max_elements_)
                            throw std::runtime_error("cand error");
                        dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
                        if (d < curdist) {
                            curdist = d;
                            currObj = cand;
                            changed = true;
                        }
                    }
                }
            }
        }

        bool epDeleted = isMarkedDeleted(enterpoint_copy);
        for (int lvl = std::min(curlevel, maxlevelcopy); lvl >= 0; lvl--) {
            if (lvl > maxlevelcopy || lvl < 0)  // possible?
                throw std::runtime_error("Level error");

            std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
                    currObj, data_point, lvl);
            if (epDeleted) {
                // A deleted entry point is still offered as a candidate, within the ef_construction_ cap.
                top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
                if (top_candidates.size() > ef_construction_)
                    top_candidates.pop();
            }
            currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, lvl, false);
        }

    } else {
        // Do nothing for the first element
        enterpoint_node_ = 0;
        maxlevel_ = curlevel;
    }

    //Releasing lock for the maximum level
    if (curlevel > maxlevelcopy) {
        enterpoint_node_ = cur_c;
        maxlevel_ = curlevel;
    }
    return cur_c;
}
|
1099
|
+
|
1100
|
+
std::priority_queue<std::pair<dist_t, labeltype >>
|
1101
|
+
searchKnn(const void *query_data, size_t k) const {
|
1102
|
+
std::priority_queue<std::pair<dist_t, labeltype >> result;
|
1103
|
+
if (cur_element_count == 0) return result;
|
1104
|
+
|
1105
|
+
tableint currObj = enterpoint_node_;
|
1106
|
+
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
|
1107
|
+
|
1108
|
+
for (int level = maxlevel_; level > 0; level--) {
|
1109
|
+
bool changed = true;
|
1110
|
+
while (changed) {
|
1111
|
+
changed = false;
|
1112
|
+
unsigned int *data;
|
1113
|
+
|
1114
|
+
data = (unsigned int *) get_linklist(currObj, level);
|
1115
|
+
int size = getListCount(data);
|
1116
|
+
metric_hops++;
|
1117
|
+
metric_distance_computations+=size;
|
1118
|
+
|
1119
|
+
tableint *datal = (tableint *) (data + 1);
|
1120
|
+
for (int i = 0; i < size; i++) {
|
1121
|
+
tableint cand = datal[i];
|
1122
|
+
if (cand < 0 || cand > max_elements_)
|
1123
|
+
throw std::runtime_error("cand error");
|
1124
|
+
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
|
1125
|
+
|
1126
|
+
if (d < curdist) {
|
1127
|
+
curdist = d;
|
1128
|
+
currObj = cand;
|
1129
|
+
changed = true;
|
1130
|
+
}
|
1131
|
+
}
|
1132
|
+
}
|
1133
|
+
}
|
1134
|
+
|
1135
|
+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
|
1136
|
+
if (has_deletions_) {
|
1137
|
+
top_candidates=searchBaseLayerST<true,true>(
|
1138
|
+
currObj, query_data, std::max(ef_, k));
|
1139
|
+
}
|
1140
|
+
else{
|
1141
|
+
top_candidates=searchBaseLayerST<false,true>(
|
1142
|
+
currObj, query_data, std::max(ef_, k));
|
1143
|
+
}
|
1144
|
+
|
1145
|
+
while (top_candidates.size() > k) {
|
1146
|
+
top_candidates.pop();
|
1147
|
+
}
|
1148
|
+
while (top_candidates.size() > 0) {
|
1149
|
+
std::pair<dist_t, tableint> rez = top_candidates.top();
|
1150
|
+
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
|
1151
|
+
top_candidates.pop();
|
1152
|
+
}
|
1153
|
+
return result;
|
1154
|
+
};
|
1155
|
+
|
1156
|
+
void checkIntegrity(){
|
1157
|
+
int connections_checked=0;
|
1158
|
+
std::vector <int > inbound_connections_num(cur_element_count,0);
|
1159
|
+
for(int i = 0;i < cur_element_count; i++){
|
1160
|
+
for(int l = 0;l <= element_levels_[i]; l++){
|
1161
|
+
linklistsizeint *ll_cur = get_linklist_at_level(i,l);
|
1162
|
+
int size = getListCount(ll_cur);
|
1163
|
+
tableint *data = (tableint *) (ll_cur + 1);
|
1164
|
+
std::unordered_set<tableint> s;
|
1165
|
+
for (int j=0; j<size; j++){
|
1166
|
+
assert(data[j] > 0);
|
1167
|
+
assert(data[j] < cur_element_count);
|
1168
|
+
assert (data[j] != i);
|
1169
|
+
inbound_connections_num[data[j]]++;
|
1170
|
+
s.insert(data[j]);
|
1171
|
+
connections_checked++;
|
1172
|
+
|
1173
|
+
}
|
1174
|
+
assert(s.size() == size);
|
1175
|
+
}
|
1176
|
+
}
|
1177
|
+
if(cur_element_count > 1){
|
1178
|
+
int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
|
1179
|
+
for(int i=0; i < cur_element_count; i++){
|
1180
|
+
assert(inbound_connections_num[i] > 0);
|
1181
|
+
min1=std::min(inbound_connections_num[i],min1);
|
1182
|
+
max1=std::max(inbound_connections_num[i],max1);
|
1183
|
+
}
|
1184
|
+
std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
|
1185
|
+
}
|
1186
|
+
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
|
1187
|
+
|
1188
|
+
}
|
1189
|
+
|
1190
|
+
};
|
1191
|
+
|
1192
|
+
}
|