umappp 0.1.6 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,1183 +10,1262 @@
10
10
  #include <list>
11
11
 
12
12
  namespace hnswlib {
13
- typedef unsigned int tableint;
14
- typedef unsigned int linklistsizeint;
13
+ typedef unsigned int tableint;
14
+ typedef unsigned int linklistsizeint;
15
+
16
+ template<typename dist_t>
17
+ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
+ public:
19
+ static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
20
+ static const unsigned char DELETE_MARK = 0x01;
21
+
22
+ size_t max_elements_{0};
23
+ mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
24
+ size_t size_data_per_element_{0};
25
+ size_t size_links_per_element_{0};
26
+ mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
27
+ size_t M_{0};
28
+ size_t maxM_{0};
29
+ size_t maxM0_{0};
30
+ size_t ef_construction_{0};
31
+ size_t ef_{ 0 };
32
+
33
+ double mult_{0.0}, revSize_{0.0};
34
+ int maxlevel_{0};
35
+
36
+ VisitedListPool *visited_list_pool_{nullptr};
37
+
38
+ // Locks operations with element by label value
39
+ mutable std::vector<std::mutex> label_op_locks_;
40
+
41
+ std::mutex global;
42
+ std::vector<std::mutex> link_list_locks_;
43
+
44
+ tableint enterpoint_node_{0};
45
+
46
+ size_t size_links_level0_{0};
47
+ size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 };
48
+
49
+ char *data_level0_memory_{nullptr};
50
+ char **linkLists_{nullptr};
51
+ std::vector<int> element_levels_; // keeps level of each element
52
+
53
+ size_t data_size_{0};
54
+
55
+ DISTFUNC<dist_t> fstdistfunc_;
56
+ void *dist_func_param_{nullptr};
57
+
58
+ mutable std::mutex label_lookup_lock; // lock for label_lookup_
59
+ std::unordered_map<labeltype, tableint> label_lookup_;
60
+
61
+ std::default_random_engine level_generator_;
62
+ std::default_random_engine update_probability_generator_;
63
+
64
+ mutable std::atomic<long> metric_distance_computations{0};
65
+ mutable std::atomic<long> metric_hops{0};
66
+
67
+ bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
68
+
69
+ std::mutex deleted_elements_lock; // lock for deleted_elements
70
+ std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
71
+
72
+
73
+ HierarchicalNSW(SpaceInterface<dist_t> *s) {
74
+ }
75
+
76
+
77
+ HierarchicalNSW(
78
+ SpaceInterface<dist_t> *s,
79
+ const std::string &location,
80
+ bool nmslib = false,
81
+ size_t max_elements = 0,
82
+ bool allow_replace_deleted = false)
83
+ : allow_replace_deleted_(allow_replace_deleted) {
84
+ loadIndex(location, s, max_elements);
85
+ }
86
+
87
+
88
+ HierarchicalNSW(
89
+ SpaceInterface<dist_t> *s,
90
+ size_t max_elements,
91
+ size_t M = 16,
92
+ size_t ef_construction = 200,
93
+ size_t random_seed = 100,
94
+ bool allow_replace_deleted = false)
95
+ : label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
96
+ link_list_locks_(max_elements),
97
+ element_levels_(max_elements),
98
+ allow_replace_deleted_(allow_replace_deleted) {
99
+ max_elements_ = max_elements;
100
+ num_deleted_ = 0;
101
+ data_size_ = s->get_data_size();
102
+ fstdistfunc_ = s->get_dist_func();
103
+ dist_func_param_ = s->get_dist_func_param();
104
+ M_ = M;
105
+ maxM_ = M_;
106
+ maxM0_ = M_ * 2;
107
+ ef_construction_ = std::max(ef_construction, M_);
108
+ ef_ = 10;
109
+
110
+ level_generator_.seed(random_seed);
111
+ update_probability_generator_.seed(random_seed + 1);
112
+
113
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
114
+ size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
115
+ offsetData_ = size_links_level0_;
116
+ label_offset_ = size_links_level0_ + data_size_;
117
+ offsetLevel0_ = 0;
118
+
119
+ data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
120
+ if (data_level0_memory_ == nullptr)
121
+ throw std::runtime_error("Not enough memory");
122
+
123
+ cur_element_count = 0;
124
+
125
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
15
126
 
16
- template<typename dist_t>
17
- class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
- public:
19
- static const tableint max_update_element_locks = 65536;
20
- HierarchicalNSW(SpaceInterface<dist_t> *s) {
127
+ // initializations for special treatment of the first node
128
+ enterpoint_node_ = -1;
129
+ maxlevel_ = -1;
21
130
 
22
- }
23
-
24
- HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
25
- loadIndex(location, s, max_elements);
26
- }
27
-
28
- HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
29
- link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
30
- max_elements_ = max_elements;
31
-
32
- has_deletions_=false;
33
- data_size_ = s->get_data_size();
34
- fstdistfunc_ = s->get_dist_func();
35
- dist_func_param_ = s->get_dist_func_param();
36
- M_ = M;
37
- maxM_ = M_;
38
- maxM0_ = M_ * 2;
39
- ef_construction_ = std::max(ef_construction,M_);
40
- ef_ = 10;
41
-
42
- level_generator_.seed(random_seed);
43
- update_probability_generator_.seed(random_seed + 1);
44
-
45
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
46
- size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
47
- offsetData_ = size_links_level0_;
48
- label_offset_ = size_links_level0_ + data_size_;
49
- offsetLevel0_ = 0;
50
-
51
- data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
52
- if (data_level0_memory_ == nullptr)
53
- throw std::runtime_error("Not enough memory");
54
-
55
- cur_element_count = 0;
56
-
57
- visited_list_pool_ = new VisitedListPool(1, max_elements);
131
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
132
+ if (linkLists_ == nullptr)
133
+ throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
134
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
135
+ mult_ = 1 / log(1.0 * M_);
136
+ revSize_ = 1.0 / mult_;
137
+ }
58
138
 
59
139
 
60
-
61
- //initializations for special treatment of the first node
62
- enterpoint_node_ = -1;
63
- maxlevel_ = -1;
64
-
65
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
66
- if (linkLists_ == nullptr)
67
- throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
68
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
69
- mult_ = 1 / log(1.0 * M_);
70
- revSize_ = 1.0 / mult_;
140
+ ~HierarchicalNSW() {
141
+ free(data_level0_memory_);
142
+ for (tableint i = 0; i < cur_element_count; i++) {
143
+ if (element_levels_[i] > 0)
144
+ free(linkLists_[i]);
71
145
  }
146
+ free(linkLists_);
147
+ delete visited_list_pool_;
148
+ }
72
149
 
73
- struct CompareByFirst {
74
- constexpr bool operator()(std::pair<dist_t, tableint> const &a,
75
- std::pair<dist_t, tableint> const &b) const noexcept {
76
- return a.first < b.first;
77
- }
78
- };
79
150
 
80
- ~HierarchicalNSW() {
81
-
82
- free(data_level0_memory_);
83
- for (tableint i = 0; i < cur_element_count; i++) {
84
- if (element_levels_[i] > 0)
85
- free(linkLists_[i]);
86
- }
87
- free(linkLists_);
88
- delete visited_list_pool_;
151
+ struct CompareByFirst {
152
+ constexpr bool operator()(std::pair<dist_t, tableint> const& a,
153
+ std::pair<dist_t, tableint> const& b) const noexcept {
154
+ return a.first < b.first;
89
155
  }
156
+ };
90
157
 
91
- size_t max_elements_;
92
- size_t cur_element_count;
93
- size_t size_data_per_element_;
94
- size_t size_links_per_element_;
95
158
 
96
- size_t M_;
97
- size_t maxM_;
98
- size_t maxM0_;
99
- size_t ef_construction_;
159
+ void setEf(size_t ef) {
160
+ ef_ = ef;
161
+ }
100
162
 
101
- double mult_, revSize_;
102
- int maxlevel_;
103
163
 
164
+ inline std::mutex& getLabelOpMutex(labeltype label) const {
165
+ // calculate hash
166
+ size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
167
+ return label_op_locks_[lock_id];
168
+ }
104
169
 
105
- VisitedListPool *visited_list_pool_;
106
- std::mutex cur_element_count_guard_;
107
170
 
108
- std::vector<std::mutex> link_list_locks_;
171
+ inline labeltype getExternalLabel(tableint internal_id) const {
172
+ labeltype return_label;
173
+ memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
174
+ return return_label;
175
+ }
109
176
 
110
- // Locks to prevent race condition during update/insert of an element at same time.
111
- // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
112
- std::vector<std::mutex> link_list_update_locks_;
113
- tableint enterpoint_node_;
114
177
 
178
+ inline void setExternalLabel(tableint internal_id, labeltype label) const {
179
+ memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
180
+ }
115
181
 
116
- size_t size_links_level0_;
117
- size_t offsetData_, offsetLevel0_;
118
182
 
183
+ inline labeltype *getExternalLabeLp(tableint internal_id) const {
184
+ return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
185
+ }
119
186
 
120
- char *data_level0_memory_;
121
- char **linkLists_;
122
- std::vector<int> element_levels_;
123
187
 
124
- size_t data_size_;
188
+ inline char *getDataByInternalId(tableint internal_id) const {
189
+ return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
190
+ }
125
191
 
126
- bool has_deletions_;
127
192
 
193
+ int getRandomLevel(double reverse_size) {
194
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
195
+ double r = -log(distribution(level_generator_)) * reverse_size;
196
+ return (int) r;
197
+ }
128
198
 
129
- size_t label_offset_;
130
- DISTFUNC<dist_t> fstdistfunc_;
131
- void *dist_func_param_;
132
- std::unordered_map<labeltype, tableint> label_lookup_;
199
+ size_t getMaxElements() {
200
+ return max_elements_;
201
+ }
133
202
 
134
- std::default_random_engine level_generator_;
135
- std::default_random_engine update_probability_generator_;
203
+ size_t getCurrentElementCount() {
204
+ return cur_element_count;
205
+ }
136
206
 
137
- inline labeltype getExternalLabel(tableint internal_id) const {
138
- labeltype return_label;
139
- memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
140
- return return_label;
141
- }
207
+ size_t getDeletedCount() {
208
+ return num_deleted_;
209
+ }
142
210
 
143
- inline void setExternalLabel(tableint internal_id, labeltype label) const {
144
- memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
145
- }
211
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
212
+ searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
213
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
214
+ vl_type *visited_array = vl->mass;
215
+ vl_type visited_array_tag = vl->curV;
146
216
 
147
- inline labeltype *getExternalLabeLp(tableint internal_id) const {
148
- return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
149
- }
217
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
218
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
150
219
 
151
- inline char *getDataByInternalId(tableint internal_id) const {
152
- return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
220
+ dist_t lowerBound;
221
+ if (!isMarkedDeleted(ep_id)) {
222
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
223
+ top_candidates.emplace(dist, ep_id);
224
+ lowerBound = dist;
225
+ candidateSet.emplace(-dist, ep_id);
226
+ } else {
227
+ lowerBound = std::numeric_limits<dist_t>::max();
228
+ candidateSet.emplace(-lowerBound, ep_id);
153
229
  }
230
+ visited_array[ep_id] = visited_array_tag;
154
231
 
155
- int getRandomLevel(double reverse_size) {
156
- std::uniform_real_distribution<double> distribution(0.0, 1.0);
157
- double r = -log(distribution(level_generator_)) * reverse_size;
158
- return (int) r;
159
- }
160
-
161
-
162
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
163
- searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
164
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
165
- vl_type *visited_array = vl->mass;
166
- vl_type visited_array_tag = vl->curV;
167
-
168
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
169
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
170
-
171
- dist_t lowerBound;
172
- if (!isMarkedDeleted(ep_id)) {
173
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
174
- top_candidates.emplace(dist, ep_id);
175
- lowerBound = dist;
176
- candidateSet.emplace(-dist, ep_id);
177
- } else {
178
- lowerBound = std::numeric_limits<dist_t>::max();
179
- candidateSet.emplace(-lowerBound, ep_id);
232
+ while (!candidateSet.empty()) {
233
+ std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
234
+ if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
235
+ break;
180
236
  }
181
- visited_array[ep_id] = visited_array_tag;
182
-
183
- while (!candidateSet.empty()) {
184
- std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
185
- if ((-curr_el_pair.first) > lowerBound) {
186
- break;
187
- }
188
- candidateSet.pop();
237
+ candidateSet.pop();
189
238
 
190
- tableint curNodeNum = curr_el_pair.second;
239
+ tableint curNodeNum = curr_el_pair.second;
191
240
 
192
- std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
241
+ std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
193
242
 
194
- int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
195
- if (layer == 0) {
196
- data = (int*)get_linklist0(curNodeNum);
197
- } else {
198
- data = (int*)get_linklist(curNodeNum, layer);
243
+ int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
244
+ if (layer == 0) {
245
+ data = (int*)get_linklist0(curNodeNum);
246
+ } else {
247
+ data = (int*)get_linklist(curNodeNum, layer);
199
248
  // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
200
- }
201
- size_t size = getListCount((linklistsizeint*)data);
202
- tableint *datal = (tableint *) (data + 1);
249
+ }
250
+ size_t size = getListCount((linklistsizeint*)data);
251
+ tableint *datal = (tableint *) (data + 1);
203
252
  #ifdef USE_SSE
204
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
205
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
206
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
207
- _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
253
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
254
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
255
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
256
+ _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
208
257
  #endif
209
258
 
210
- for (size_t j = 0; j < size; j++) {
211
- tableint candidate_id = *(datal + j);
259
+ for (size_t j = 0; j < size; j++) {
260
+ tableint candidate_id = *(datal + j);
212
261
  // if (candidate_id == 0) continue;
213
262
  #ifdef USE_SSE
214
- _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
215
- _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
263
+ _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
264
+ _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
216
265
  #endif
217
- if (visited_array[candidate_id] == visited_array_tag) continue;
218
- visited_array[candidate_id] = visited_array_tag;
219
- char *currObj1 = (getDataByInternalId(candidate_id));
266
+ if (visited_array[candidate_id] == visited_array_tag) continue;
267
+ visited_array[candidate_id] = visited_array_tag;
268
+ char *currObj1 = (getDataByInternalId(candidate_id));
220
269
 
221
- dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
222
- if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
223
- candidateSet.emplace(-dist1, candidate_id);
270
+ dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
271
+ if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
272
+ candidateSet.emplace(-dist1, candidate_id);
224
273
  #ifdef USE_SSE
225
- _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
274
+ _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
226
275
  #endif
227
276
 
228
- if (!isMarkedDeleted(candidate_id))
229
- top_candidates.emplace(dist1, candidate_id);
277
+ if (!isMarkedDeleted(candidate_id))
278
+ top_candidates.emplace(dist1, candidate_id);
230
279
 
231
- if (top_candidates.size() > ef_construction_)
232
- top_candidates.pop();
280
+ if (top_candidates.size() > ef_construction_)
281
+ top_candidates.pop();
233
282
 
234
- if (!top_candidates.empty())
235
- lowerBound = top_candidates.top().first;
236
- }
283
+ if (!top_candidates.empty())
284
+ lowerBound = top_candidates.top().first;
237
285
  }
238
286
  }
239
- visited_list_pool_->releaseVisitedList(vl);
240
-
241
- return top_candidates;
287
+ }
288
+ visited_list_pool_->releaseVisitedList(vl);
289
+
290
+ return top_candidates;
291
+ }
292
+
293
+
294
+ template <bool has_deletions, bool collect_metrics = false>
295
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
296
+ searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
297
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
298
+ vl_type *visited_array = vl->mass;
299
+ vl_type visited_array_tag = vl->curV;
300
+
301
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
302
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
303
+
304
+ dist_t lowerBound;
305
+ if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
306
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
307
+ lowerBound = dist;
308
+ top_candidates.emplace(dist, ep_id);
309
+ candidate_set.emplace(-dist, ep_id);
310
+ } else {
311
+ lowerBound = std::numeric_limits<dist_t>::max();
312
+ candidate_set.emplace(-lowerBound, ep_id);
242
313
  }
243
314
 
244
- mutable std::atomic<long> metric_distance_computations;
245
- mutable std::atomic<long> metric_hops;
246
-
247
- template <bool has_deletions, bool collect_metrics=false>
248
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
249
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
250
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
251
- vl_type *visited_array = vl->mass;
252
- vl_type visited_array_tag = vl->curV;
253
-
254
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
255
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
256
-
257
- dist_t lowerBound;
258
- if (!has_deletions || !isMarkedDeleted(ep_id)) {
259
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
260
- lowerBound = dist;
261
- top_candidates.emplace(dist, ep_id);
262
- candidate_set.emplace(-dist, ep_id);
263
- } else {
264
- lowerBound = std::numeric_limits<dist_t>::max();
265
- candidate_set.emplace(-lowerBound, ep_id);
266
- }
267
-
268
- visited_array[ep_id] = visited_array_tag;
269
-
270
- while (!candidate_set.empty()) {
315
+ visited_array[ep_id] = visited_array_tag;
271
316
 
272
- std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
317
+ while (!candidate_set.empty()) {
318
+ std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
273
319
 
274
- if ((-current_node_pair.first) > lowerBound) {
275
- break;
276
- }
277
- candidate_set.pop();
320
+ if ((-current_node_pair.first) > lowerBound &&
321
+ (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
322
+ break;
323
+ }
324
+ candidate_set.pop();
278
325
 
279
- tableint current_node_id = current_node_pair.second;
280
- int *data = (int *) get_linklist0(current_node_id);
281
- size_t size = getListCount((linklistsizeint*)data);
326
+ tableint current_node_id = current_node_pair.second;
327
+ int *data = (int *) get_linklist0(current_node_id);
328
+ size_t size = getListCount((linklistsizeint*)data);
282
329
  // bool cur_node_deleted = isMarkedDeleted(current_node_id);
283
- if(collect_metrics){
284
- metric_hops++;
285
- metric_distance_computations+=size;
286
- }
330
+ if (collect_metrics) {
331
+ metric_hops++;
332
+ metric_distance_computations+=size;
333
+ }
287
334
 
288
335
  #ifdef USE_SSE
289
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
290
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
291
- _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
292
- _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
336
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
337
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
338
+ _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
339
+ _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
293
340
  #endif
294
341
 
295
- for (size_t j = 1; j <= size; j++) {
296
- int candidate_id = *(data + j);
342
+ for (size_t j = 1; j <= size; j++) {
343
+ int candidate_id = *(data + j);
297
344
  // if (candidate_id == 0) continue;
298
345
  #ifdef USE_SSE
299
- _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
300
- _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
301
- _MM_HINT_T0);////////////
346
+ _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
347
+ _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
348
+ _MM_HINT_T0); ////////////
302
349
  #endif
303
- if (!(visited_array[candidate_id] == visited_array_tag)) {
304
-
305
- visited_array[candidate_id] = visited_array_tag;
350
+ if (!(visited_array[candidate_id] == visited_array_tag)) {
351
+ visited_array[candidate_id] = visited_array_tag;
306
352
 
307
- char *currObj1 = (getDataByInternalId(candidate_id));
308
- dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
353
+ char *currObj1 = (getDataByInternalId(candidate_id));
354
+ dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
309
355
 
310
- if (top_candidates.size() < ef || lowerBound > dist) {
311
- candidate_set.emplace(-dist, candidate_id);
356
+ if (top_candidates.size() < ef || lowerBound > dist) {
357
+ candidate_set.emplace(-dist, candidate_id);
312
358
  #ifdef USE_SSE
313
- _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
314
- offsetLevel0_,///////////
315
- _MM_HINT_T0);////////////////////////
359
+ _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
360
+ offsetLevel0_, ///////////
361
+ _MM_HINT_T0); ////////////////////////
316
362
  #endif
317
363
 
318
- if (!has_deletions || !isMarkedDeleted(candidate_id))
319
- top_candidates.emplace(dist, candidate_id);
364
+ if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
365
+ top_candidates.emplace(dist, candidate_id);
320
366
 
321
- if (top_candidates.size() > ef)
322
- top_candidates.pop();
367
+ if (top_candidates.size() > ef)
368
+ top_candidates.pop();
323
369
 
324
- if (!top_candidates.empty())
325
- lowerBound = top_candidates.top().first;
326
- }
370
+ if (!top_candidates.empty())
371
+ lowerBound = top_candidates.top().first;
327
372
  }
328
373
  }
329
374
  }
330
-
331
- visited_list_pool_->releaseVisitedList(vl);
332
- return top_candidates;
333
375
  }
334
376
 
335
- void getNeighborsByHeuristic2(
336
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
337
- const size_t M) {
338
- if (top_candidates.size() < M) {
339
- return;
340
- }
377
+ visited_list_pool_->releaseVisitedList(vl);
378
+ return top_candidates;
379
+ }
341
380
 
342
- std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
343
- std::vector<std::pair<dist_t, tableint>> return_list;
344
- while (top_candidates.size() > 0) {
345
- queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
346
- top_candidates.pop();
347
- }
348
381
 
349
- while (queue_closest.size()) {
350
- if (return_list.size() >= M)
382
+ void getNeighborsByHeuristic2(
383
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
384
+ const size_t M) {
385
+ if (top_candidates.size() < M) {
386
+ return;
387
+ }
388
+
389
+ std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
390
+ std::vector<std::pair<dist_t, tableint>> return_list;
391
+ while (top_candidates.size() > 0) {
392
+ queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
393
+ top_candidates.pop();
394
+ }
395
+
396
+ while (queue_closest.size()) {
397
+ if (return_list.size() >= M)
398
+ break;
399
+ std::pair<dist_t, tableint> curent_pair = queue_closest.top();
400
+ dist_t dist_to_query = -curent_pair.first;
401
+ queue_closest.pop();
402
+ bool good = true;
403
+
404
+ for (std::pair<dist_t, tableint> second_pair : return_list) {
405
+ dist_t curdist =
406
+ fstdistfunc_(getDataByInternalId(second_pair.second),
407
+ getDataByInternalId(curent_pair.second),
408
+ dist_func_param_);
409
+ if (curdist < dist_to_query) {
410
+ good = false;
351
411
  break;
352
- std::pair<dist_t, tableint> curent_pair = queue_closest.top();
353
- dist_t dist_to_query = -curent_pair.first;
354
- queue_closest.pop();
355
- bool good = true;
356
-
357
- for (std::pair<dist_t, tableint> second_pair : return_list) {
358
- dist_t curdist =
359
- fstdistfunc_(getDataByInternalId(second_pair.second),
360
- getDataByInternalId(curent_pair.second),
361
- dist_func_param_);;
362
- if (curdist < dist_to_query) {
363
- good = false;
364
- break;
365
- }
366
- }
367
- if (good) {
368
- return_list.push_back(curent_pair);
369
412
  }
370
413
  }
371
-
372
- for (std::pair<dist_t, tableint> curent_pair : return_list) {
373
- top_candidates.emplace(-curent_pair.first, curent_pair.second);
414
+ if (good) {
415
+ return_list.push_back(curent_pair);
374
416
  }
375
417
  }
376
418
 
419
+ for (std::pair<dist_t, tableint> curent_pair : return_list) {
420
+ top_candidates.emplace(-curent_pair.first, curent_pair.second);
421
+ }
422
+ }
377
423
 
378
- linklistsizeint *get_linklist0(tableint internal_id) const {
379
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
380
- };
381
424
 
382
- linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
383
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
384
- };
425
+ linklistsizeint *get_linklist0(tableint internal_id) const {
426
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
427
+ }
385
428
 
386
- linklistsizeint *get_linklist(tableint internal_id, int level) const {
387
- return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
388
- };
389
429
 
390
- linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
391
- return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
392
- };
430
+ linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
431
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
432
+ }
393
433
 
394
- tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
395
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
396
- int level, bool isUpdate) {
397
- size_t Mcurmax = level ? maxM_ : maxM0_;
398
- getNeighborsByHeuristic2(top_candidates, M_);
399
- if (top_candidates.size() > M_)
400
- throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
401
434
 
402
- std::vector<tableint> selectedNeighbors;
403
- selectedNeighbors.reserve(M_);
404
- while (top_candidates.size() > 0) {
405
- selectedNeighbors.push_back(top_candidates.top().second);
406
- top_candidates.pop();
407
- }
435
+ linklistsizeint *get_linklist(tableint internal_id, int level) const {
436
+ return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
437
+ }
408
438
 
409
- tableint next_closest_entry_point = selectedNeighbors.back();
410
439
 
411
- {
412
- linklistsizeint *ll_cur;
413
- if (level == 0)
414
- ll_cur = get_linklist0(cur_c);
415
- else
416
- ll_cur = get_linklist(cur_c, level);
440
+ linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
441
+ return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
442
+ }
417
443
 
418
- if (*ll_cur && !isUpdate) {
419
- throw std::runtime_error("The newly inserted element should have blank link list");
420
- }
421
- setListCount(ll_cur,selectedNeighbors.size());
422
- tableint *data = (tableint *) (ll_cur + 1);
423
- for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
424
- if (data[idx] && !isUpdate)
425
- throw std::runtime_error("Possible memory corruption");
426
- if (level > element_levels_[selectedNeighbors[idx]])
427
- throw std::runtime_error("Trying to make a link on a non-existent level");
428
444
 
429
- data[idx] = selectedNeighbors[idx];
445
+ tableint mutuallyConnectNewElement(
446
+ const void *data_point,
447
+ tableint cur_c,
448
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
449
+ int level,
450
+ bool isUpdate) {
451
+ size_t Mcurmax = level ? maxM_ : maxM0_;
452
+ getNeighborsByHeuristic2(top_candidates, M_);
453
+ if (top_candidates.size() > M_)
454
+ throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
430
455
 
431
- }
432
- }
456
+ std::vector<tableint> selectedNeighbors;
457
+ selectedNeighbors.reserve(M_);
458
+ while (top_candidates.size() > 0) {
459
+ selectedNeighbors.push_back(top_candidates.top().second);
460
+ top_candidates.pop();
461
+ }
462
+
463
+ tableint next_closest_entry_point = selectedNeighbors.back();
433
464
 
465
+ {
466
+ // lock only during the update
467
+ // because during the addition the lock for cur_c is already acquired
468
+ std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
469
+ if (isUpdate) {
470
+ lock.lock();
471
+ }
472
+ linklistsizeint *ll_cur;
473
+ if (level == 0)
474
+ ll_cur = get_linklist0(cur_c);
475
+ else
476
+ ll_cur = get_linklist(cur_c, level);
477
+
478
+ if (*ll_cur && !isUpdate) {
479
+ throw std::runtime_error("The newly inserted element should have blank link list");
480
+ }
481
+ setListCount(ll_cur, selectedNeighbors.size());
482
+ tableint *data = (tableint *) (ll_cur + 1);
434
483
  for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
484
+ if (data[idx] && !isUpdate)
485
+ throw std::runtime_error("Possible memory corruption");
486
+ if (level > element_levels_[selectedNeighbors[idx]])
487
+ throw std::runtime_error("Trying to make a link on a non-existent level");
435
488
 
436
- std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
489
+ data[idx] = selectedNeighbors[idx];
490
+ }
491
+ }
437
492
 
438
- linklistsizeint *ll_other;
439
- if (level == 0)
440
- ll_other = get_linklist0(selectedNeighbors[idx]);
441
- else
442
- ll_other = get_linklist(selectedNeighbors[idx], level);
493
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
494
+ std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
443
495
 
444
- size_t sz_link_list_other = getListCount(ll_other);
496
+ linklistsizeint *ll_other;
497
+ if (level == 0)
498
+ ll_other = get_linklist0(selectedNeighbors[idx]);
499
+ else
500
+ ll_other = get_linklist(selectedNeighbors[idx], level);
445
501
 
446
- if (sz_link_list_other > Mcurmax)
447
- throw std::runtime_error("Bad value of sz_link_list_other");
448
- if (selectedNeighbors[idx] == cur_c)
449
- throw std::runtime_error("Trying to connect an element to itself");
450
- if (level > element_levels_[selectedNeighbors[idx]])
451
- throw std::runtime_error("Trying to make a link on a non-existent level");
502
+ size_t sz_link_list_other = getListCount(ll_other);
452
503
 
453
- tableint *data = (tableint *) (ll_other + 1);
504
+ if (sz_link_list_other > Mcurmax)
505
+ throw std::runtime_error("Bad value of sz_link_list_other");
506
+ if (selectedNeighbors[idx] == cur_c)
507
+ throw std::runtime_error("Trying to connect an element to itself");
508
+ if (level > element_levels_[selectedNeighbors[idx]])
509
+ throw std::runtime_error("Trying to make a link on a non-existent level");
454
510
 
455
- bool is_cur_c_present = false;
456
- if (isUpdate) {
457
- for (size_t j = 0; j < sz_link_list_other; j++) {
458
- if (data[j] == cur_c) {
459
- is_cur_c_present = true;
460
- break;
461
- }
511
+ tableint *data = (tableint *) (ll_other + 1);
512
+
513
+ bool is_cur_c_present = false;
514
+ if (isUpdate) {
515
+ for (size_t j = 0; j < sz_link_list_other; j++) {
516
+ if (data[j] == cur_c) {
517
+ is_cur_c_present = true;
518
+ break;
462
519
  }
463
520
  }
521
+ }
464
522
 
465
- // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
466
- if (!is_cur_c_present) {
467
- if (sz_link_list_other < Mcurmax) {
468
- data[sz_link_list_other] = cur_c;
469
- setListCount(ll_other, sz_link_list_other + 1);
470
- } else {
471
- // finding the "weakest" element to replace it with the new one
472
- dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
473
- dist_func_param_);
474
- // Heuristic:
475
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
476
- candidates.emplace(d_max, cur_c);
477
-
478
- for (size_t j = 0; j < sz_link_list_other; j++) {
479
- candidates.emplace(
480
- fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
481
- dist_func_param_), data[j]);
482
- }
523
+ // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
524
+ if (!is_cur_c_present) {
525
+ if (sz_link_list_other < Mcurmax) {
526
+ data[sz_link_list_other] = cur_c;
527
+ setListCount(ll_other, sz_link_list_other + 1);
528
+ } else {
529
+ // finding the "weakest" element to replace it with the new one
530
+ dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
531
+ dist_func_param_);
532
+ // Heuristic:
533
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
534
+ candidates.emplace(d_max, cur_c);
483
535
 
484
- getNeighborsByHeuristic2(candidates, Mcurmax);
536
+ for (size_t j = 0; j < sz_link_list_other; j++) {
537
+ candidates.emplace(
538
+ fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
539
+ dist_func_param_), data[j]);
540
+ }
485
541
 
486
- int indx = 0;
487
- while (candidates.size() > 0) {
488
- data[indx] = candidates.top().second;
489
- candidates.pop();
490
- indx++;
491
- }
542
+ getNeighborsByHeuristic2(candidates, Mcurmax);
492
543
 
493
- setListCount(ll_other, indx);
494
- // Nearest K:
495
- /*int indx = -1;
496
- for (int j = 0; j < sz_link_list_other; j++) {
497
- dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
498
- if (d > d_max) {
499
- indx = j;
500
- d_max = d;
501
- }
544
+ int indx = 0;
545
+ while (candidates.size() > 0) {
546
+ data[indx] = candidates.top().second;
547
+ candidates.pop();
548
+ indx++;
549
+ }
550
+
551
+ setListCount(ll_other, indx);
552
+ // Nearest K:
553
+ /*int indx = -1;
554
+ for (int j = 0; j < sz_link_list_other; j++) {
555
+ dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
556
+ if (d > d_max) {
557
+ indx = j;
558
+ d_max = d;
502
559
  }
503
- if (indx >= 0) {
504
- data[indx] = cur_c;
505
- } */
506
560
  }
561
+ if (indx >= 0) {
562
+ data[indx] = cur_c;
563
+ } */
507
564
  }
508
565
  }
509
-
510
- return next_closest_entry_point;
511
566
  }
512
567
 
513
- std::mutex global;
514
- size_t ef_;
515
-
516
- void setEf(size_t ef) {
517
- ef_ = ef;
518
- }
568
+ return next_closest_entry_point;
569
+ }
519
570
 
520
571
 
521
- std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
522
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
523
- if (cur_element_count == 0) return top_candidates;
524
- tableint currObj = enterpoint_node_;
525
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
526
-
527
- for (size_t level = maxlevel_; level > 0; level--) {
528
- bool changed = true;
529
- while (changed) {
530
- changed = false;
531
- int *data;
532
- data = (int *) get_linklist(currObj,level);
533
- int size = getListCount(data);
534
- tableint *datal = (tableint *) (data + 1);
535
- for (int i = 0; i < size; i++) {
536
- tableint cand = datal[i];
537
- if (cand < 0 || cand > max_elements_)
538
- throw std::runtime_error("cand error");
539
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
572
+ void resizeIndex(size_t new_max_elements) {
573
+ if (new_max_elements < cur_element_count)
574
+ throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
540
575
 
541
- if (d < curdist) {
542
- curdist = d;
543
- currObj = cand;
544
- changed = true;
545
- }
546
- }
547
- }
548
- }
576
+ delete visited_list_pool_;
577
+ visited_list_pool_ = new VisitedListPool(1, new_max_elements);
549
578
 
550
- if (has_deletions_) {
551
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
552
- ef_);
553
- top_candidates.swap(top_candidates1);
554
- }
555
- else{
556
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
557
- ef_);
558
- top_candidates.swap(top_candidates1);
559
- }
579
+ element_levels_.resize(new_max_elements);
560
580
 
561
- while (top_candidates.size() > k) {
562
- top_candidates.pop();
563
- }
564
- return top_candidates;
565
- };
581
+ std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
566
582
 
567
- void resizeIndex(size_t new_max_elements){
568
- if (new_max_elements<cur_element_count)
569
- throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
583
+ // Reallocate base layer
584
+ char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
585
+ if (data_level0_memory_new == nullptr)
586
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
587
+ data_level0_memory_ = data_level0_memory_new;
570
588
 
589
+ // Reallocate all other layers
590
+ char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
591
+ if (linkLists_new == nullptr)
592
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
593
+ linkLists_ = linkLists_new;
571
594
 
572
- delete visited_list_pool_;
573
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
595
+ max_elements_ = new_max_elements;
596
+ }
574
597
 
575
598
 
576
- element_levels_.resize(new_max_elements);
599
+ void saveIndex(const std::string &location) {
600
+ std::ofstream output(location, std::ios::binary);
601
+ std::streampos position;
577
602
 
578
- std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
603
+ writeBinaryPOD(output, offsetLevel0_);
604
+ writeBinaryPOD(output, max_elements_);
605
+ writeBinaryPOD(output, cur_element_count);
606
+ writeBinaryPOD(output, size_data_per_element_);
607
+ writeBinaryPOD(output, label_offset_);
608
+ writeBinaryPOD(output, offsetData_);
609
+ writeBinaryPOD(output, maxlevel_);
610
+ writeBinaryPOD(output, enterpoint_node_);
611
+ writeBinaryPOD(output, maxM_);
579
612
 
580
- // Reallocate base layer
581
- char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
582
- if (data_level0_memory_new == nullptr)
583
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
584
- data_level0_memory_ = data_level0_memory_new;
613
+ writeBinaryPOD(output, maxM0_);
614
+ writeBinaryPOD(output, M_);
615
+ writeBinaryPOD(output, mult_);
616
+ writeBinaryPOD(output, ef_construction_);
585
617
 
586
- // Reallocate all other layers
587
- char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
588
- if (linkLists_new == nullptr)
589
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
590
- linkLists_ = linkLists_new;
618
+ output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
591
619
 
592
- max_elements_ = new_max_elements;
620
+ for (size_t i = 0; i < cur_element_count; i++) {
621
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
622
+ writeBinaryPOD(output, linkListSize);
623
+ if (linkListSize)
624
+ output.write(linkLists_[i], linkListSize);
593
625
  }
626
+ output.close();
627
+ }
628
+
629
+
630
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) {
631
+ std::ifstream input(location, std::ios::binary);
632
+
633
+ if (!input.is_open())
634
+ throw std::runtime_error("Cannot open file");
635
+
636
+ // get file size:
637
+ input.seekg(0, input.end);
638
+ std::streampos total_filesize = input.tellg();
639
+ input.seekg(0, input.beg);
640
+
641
+ readBinaryPOD(input, offsetLevel0_);
642
+ readBinaryPOD(input, max_elements_);
643
+ readBinaryPOD(input, cur_element_count);
644
+
645
+ size_t max_elements = max_elements_i;
646
+ if (max_elements < cur_element_count)
647
+ max_elements = max_elements_;
648
+ max_elements_ = max_elements;
649
+ readBinaryPOD(input, size_data_per_element_);
650
+ readBinaryPOD(input, label_offset_);
651
+ readBinaryPOD(input, offsetData_);
652
+ readBinaryPOD(input, maxlevel_);
653
+ readBinaryPOD(input, enterpoint_node_);
654
+
655
+ readBinaryPOD(input, maxM_);
656
+ readBinaryPOD(input, maxM0_);
657
+ readBinaryPOD(input, M_);
658
+ readBinaryPOD(input, mult_);
659
+ readBinaryPOD(input, ef_construction_);
660
+
661
+ data_size_ = s->get_data_size();
662
+ fstdistfunc_ = s->get_dist_func();
663
+ dist_func_param_ = s->get_dist_func_param();
664
+
665
+ auto pos = input.tellg();
666
+
667
+ /// Optional - check if index is ok:
668
+ input.seekg(cur_element_count * size_data_per_element_, input.cur);
669
+ for (size_t i = 0; i < cur_element_count; i++) {
670
+ if (input.tellg() < 0 || input.tellg() >= total_filesize) {
671
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
672
+ }
594
673
 
595
- void saveIndex(const std::string &location) {
596
- std::ofstream output(location, std::ios::binary);
597
- std::streampos position;
598
-
599
- writeBinaryPOD(output, offsetLevel0_);
600
- writeBinaryPOD(output, max_elements_);
601
- writeBinaryPOD(output, cur_element_count);
602
- writeBinaryPOD(output, size_data_per_element_);
603
- writeBinaryPOD(output, label_offset_);
604
- writeBinaryPOD(output, offsetData_);
605
- writeBinaryPOD(output, maxlevel_);
606
- writeBinaryPOD(output, enterpoint_node_);
607
- writeBinaryPOD(output, maxM_);
608
-
609
- writeBinaryPOD(output, maxM0_);
610
- writeBinaryPOD(output, M_);
611
- writeBinaryPOD(output, mult_);
612
- writeBinaryPOD(output, ef_construction_);
613
-
614
- output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
615
-
616
- for (size_t i = 0; i < cur_element_count; i++) {
617
- unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
618
- writeBinaryPOD(output, linkListSize);
619
- if (linkListSize)
620
- output.write(linkLists_[i], linkListSize);
674
+ unsigned int linkListSize;
675
+ readBinaryPOD(input, linkListSize);
676
+ if (linkListSize != 0) {
677
+ input.seekg(linkListSize, input.cur);
621
678
  }
622
- output.close();
623
679
  }
624
680
 
625
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
626
-
627
-
628
- std::ifstream input(location, std::ios::binary);
629
-
630
- if (!input.is_open())
631
- throw std::runtime_error("Cannot open file");
681
+ // throw exception if it either corrupted or old index
682
+ if (input.tellg() != total_filesize)
683
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
684
+
685
+ input.clear();
686
+ /// Optional check end
687
+
688
+ input.seekg(pos, input.beg);
689
+
690
+ data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
691
+ if (data_level0_memory_ == nullptr)
692
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
693
+ input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
694
+
695
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
696
+
697
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
698
+ std::vector<std::mutex>(max_elements).swap(link_list_locks_);
699
+ std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
700
+
701
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
702
+
703
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
704
+ if (linkLists_ == nullptr)
705
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
706
+ element_levels_ = std::vector<int>(max_elements);
707
+ revSize_ = 1.0 / mult_;
708
+ ef_ = 10;
709
+ for (size_t i = 0; i < cur_element_count; i++) {
710
+ label_lookup_[getExternalLabel(i)] = i;
711
+ unsigned int linkListSize;
712
+ readBinaryPOD(input, linkListSize);
713
+ if (linkListSize == 0) {
714
+ element_levels_[i] = 0;
715
+ linkLists_[i] = nullptr;
716
+ } else {
717
+ element_levels_[i] = linkListSize / size_links_per_element_;
718
+ linkLists_[i] = (char *) malloc(linkListSize);
719
+ if (linkLists_[i] == nullptr)
720
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
721
+ input.read(linkLists_[i], linkListSize);
722
+ }
723
+ }
632
724
 
633
- // get file size:
634
- input.seekg(0,input.end);
635
- std::streampos total_filesize=input.tellg();
636
- input.seekg(0,input.beg);
725
+ for (size_t i = 0; i < cur_element_count; i++) {
726
+ if (isMarkedDeleted(i)) {
727
+ num_deleted_ += 1;
728
+ if (allow_replace_deleted_) deleted_elements.insert(i);
729
+ }
730
+ }
637
731
 
638
- readBinaryPOD(input, offsetLevel0_);
639
- readBinaryPOD(input, max_elements_);
640
- readBinaryPOD(input, cur_element_count);
732
+ input.close();
641
733
 
642
- size_t max_elements=max_elements_i;
643
- if(max_elements < cur_element_count)
644
- max_elements = max_elements_;
645
- max_elements_ = max_elements;
646
- readBinaryPOD(input, size_data_per_element_);
647
- readBinaryPOD(input, label_offset_);
648
- readBinaryPOD(input, offsetData_);
649
- readBinaryPOD(input, maxlevel_);
650
- readBinaryPOD(input, enterpoint_node_);
734
+ return;
735
+ }
651
736
 
652
- readBinaryPOD(input, maxM_);
653
- readBinaryPOD(input, maxM0_);
654
- readBinaryPOD(input, M_);
655
- readBinaryPOD(input, mult_);
656
- readBinaryPOD(input, ef_construction_);
657
737
 
738
+ template<typename data_t>
739
+ std::vector<data_t> getDataByLabel(labeltype label) const {
740
+ // lock all operations with element by label
741
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
742
+
743
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
744
+ auto search = label_lookup_.find(label);
745
+ if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
746
+ throw std::runtime_error("Label not found");
747
+ }
748
+ tableint internalId = search->second;
749
+ lock_table.unlock();
750
+
751
+ char* data_ptrv = getDataByInternalId(internalId);
752
+ size_t dim = *((size_t *) dist_func_param_);
753
+ std::vector<data_t> data;
754
+ data_t* data_ptr = (data_t*) data_ptrv;
755
+ for (int i = 0; i < dim; i++) {
756
+ data.push_back(*data_ptr);
757
+ data_ptr += 1;
758
+ }
759
+ return data;
760
+ }
658
761
 
659
- data_size_ = s->get_data_size();
660
- fstdistfunc_ = s->get_dist_func();
661
- dist_func_param_ = s->get_dist_func_param();
662
762
 
663
- auto pos=input.tellg();
763
+ /*
764
+ * Marks an element with the given label deleted, does NOT really change the current graph.
765
+ */
766
+ void markDelete(labeltype label) {
767
+ // lock all operations with element by label
768
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
664
769
 
770
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
771
+ auto search = label_lookup_.find(label);
772
+ if (search == label_lookup_.end()) {
773
+ throw std::runtime_error("Label not found");
774
+ }
775
+ tableint internalId = search->second;
776
+ lock_table.unlock();
665
777
 
666
- /// Optional - check if index is ok:
778
+ markDeletedInternal(internalId);
779
+ }
667
780
 
668
- input.seekg(cur_element_count * size_data_per_element_,input.cur);
669
- for (size_t i = 0; i < cur_element_count; i++) {
670
- if(input.tellg() < 0 || input.tellg()>=total_filesize){
671
- throw std::runtime_error("Index seems to be corrupted or unsupported");
672
- }
673
781
 
674
- unsigned int linkListSize;
675
- readBinaryPOD(input, linkListSize);
676
- if (linkListSize != 0) {
677
- input.seekg(linkListSize,input.cur);
678
- }
782
+ /*
783
+ * Uses the last 16 bits of the memory for the linked list size to store the mark,
784
+ * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases.
785
+ */
786
+ void markDeletedInternal(tableint internalId) {
787
+ assert(internalId < cur_element_count);
788
+ if (!isMarkedDeleted(internalId)) {
789
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
790
+ *ll_cur |= DELETE_MARK;
791
+ num_deleted_ += 1;
792
+ if (allow_replace_deleted_) {
793
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
794
+ deleted_elements.insert(internalId);
679
795
  }
796
+ } else {
797
+ throw std::runtime_error("The requested to delete element is already deleted");
798
+ }
799
+ }
800
+
801
+
802
+ /*
803
+ * Removes the deleted mark of the node, does NOT really change the current graph.
804
+ *
805
+ * Note: the method is not safe to use when replacement of deleted elements is enabled,
806
+ * because elements marked as deleted can be completely removed by addPoint
807
+ */
808
+ void unmarkDelete(labeltype label) {
809
+ // lock all operations with element by label
810
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
811
+
812
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
813
+ auto search = label_lookup_.find(label);
814
+ if (search == label_lookup_.end()) {
815
+ throw std::runtime_error("Label not found");
816
+ }
817
+ tableint internalId = search->second;
818
+ lock_table.unlock();
680
819
 
681
- // throw exception if it either corrupted or old index
682
- if(input.tellg()!=total_filesize)
683
- throw std::runtime_error("Index seems to be corrupted or unsupported");
684
-
685
- input.clear();
686
-
687
- /// Optional check end
688
-
689
- input.seekg(pos,input.beg);
690
-
691
-
692
- data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
693
- if (data_level0_memory_ == nullptr)
694
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
695
- input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
696
-
697
-
820
+ unmarkDeletedInternal(internalId);
821
+ }
698
822
 
699
823
 
700
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
701
824
 
825
+ /*
826
+ * Remove the deleted mark of the node.
827
+ */
828
+ void unmarkDeletedInternal(tableint internalId) {
829
+ assert(internalId < cur_element_count);
830
+ if (isMarkedDeleted(internalId)) {
831
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;
832
+ *ll_cur &= ~DELETE_MARK;
833
+ num_deleted_ -= 1;
834
+ if (allow_replace_deleted_) {
835
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
836
+ deleted_elements.erase(internalId);
837
+ }
838
+ } else {
839
+ throw std::runtime_error("The requested to undelete element is not deleted");
840
+ }
841
+ }
702
842
 
703
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
704
- std::vector<std::mutex>(max_elements).swap(link_list_locks_);
705
- std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
706
843
 
844
+ /*
845
+ * Checks the first 16 bits of the memory to see if the element is marked deleted.
846
+ */
847
+ bool isMarkedDeleted(tableint internalId) const {
848
+ unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2;
849
+ return *ll_cur & DELETE_MARK;
850
+ }
707
851
 
708
- visited_list_pool_ = new VisitedListPool(1, max_elements);
709
852
 
853
+ unsigned short int getListCount(linklistsizeint * ptr) const {
854
+ return *((unsigned short int *)ptr);
855
+ }
710
856
 
711
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
712
- if (linkLists_ == nullptr)
713
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
714
- element_levels_ = std::vector<int>(max_elements);
715
- revSize_ = 1.0 / mult_;
716
- ef_ = 10;
717
- for (size_t i = 0; i < cur_element_count; i++) {
718
- label_lookup_[getExternalLabel(i)]=i;
719
- unsigned int linkListSize;
720
- readBinaryPOD(input, linkListSize);
721
- if (linkListSize == 0) {
722
- element_levels_[i] = 0;
723
857
 
724
- linkLists_[i] = nullptr;
725
- } else {
726
- element_levels_[i] = linkListSize / size_links_per_element_;
727
- linkLists_[i] = (char *) malloc(linkListSize);
728
- if (linkLists_[i] == nullptr)
729
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
730
- input.read(linkLists_[i], linkListSize);
731
- }
732
- }
858
+ void setListCount(linklistsizeint * ptr, unsigned short int size) const {
859
+ *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
860
+ }
733
861
 
734
- has_deletions_=false;
735
862
 
736
- for (size_t i = 0; i < cur_element_count; i++) {
737
- if(isMarkedDeleted(i))
738
- has_deletions_=true;
739
- }
740
-
741
- input.close();
863
+ /*
864
+ * Adds point. Updates the point if it is already in the index.
865
+ * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
866
+ */
867
+ void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
868
+ if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
869
+ throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
870
+ }
742
871
 
872
+ // lock all operations with element by label
873
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
874
+ if (!replace_deleted) {
875
+ addPoint(data_point, label, -1);
743
876
  return;
744
877
  }
745
-
746
- template<typename data_t>
747
- std::vector<data_t> getDataByLabel(labeltype label) const
748
- {
749
- tableint label_c;
750
- auto search = label_lookup_.find(label);
751
- if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
752
- throw std::runtime_error("Label not found");
753
- }
754
- label_c = search->second;
755
-
756
- char* data_ptrv = getDataByInternalId(label_c);
757
- size_t dim = *((size_t *) dist_func_param_);
758
- std::vector<data_t> data;
759
- data_t* data_ptr = (data_t*) data_ptrv;
760
- for (int i = 0; i < dim; i++) {
761
- data.push_back(*data_ptr);
762
- data_ptr += 1;
763
- }
764
- return data;
878
+ // check if there is vacant place
879
+ tableint internal_id_replaced;
880
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
881
+ bool is_vacant_place = !deleted_elements.empty();
882
+ if (is_vacant_place) {
883
+ internal_id_replaced = *deleted_elements.begin();
884
+ deleted_elements.erase(internal_id_replaced);
765
885
  }
766
-
767
- static const unsigned char DELETE_MARK = 0x01;
768
- // static const unsigned char REUSE_MARK = 0x10;
769
- /**
770
- * Marks an element with the given label deleted, does NOT really change the current graph.
771
- * @param label
772
- */
773
- void markDelete(labeltype label)
774
- {
775
- has_deletions_=true;
776
- auto search = label_lookup_.find(label);
777
- if (search == label_lookup_.end()) {
778
- throw std::runtime_error("Label not found");
779
- }
780
- markDeletedInternal(search->second);
886
+ lock_deleted_elements.unlock();
887
+
888
+ // if there is no vacant place then add or update point
889
+ // else add point to vacant place
890
+ if (!is_vacant_place) {
891
+ addPoint(data_point, label, -1);
892
+ } else {
893
+ // we assume that there are no concurrent operations on deleted element
894
+ labeltype label_replaced = getExternalLabel(internal_id_replaced);
895
+ setExternalLabel(internal_id_replaced, label);
896
+
897
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
898
+ label_lookup_.erase(label_replaced);
899
+ label_lookup_[label] = internal_id_replaced;
900
+ lock_table.unlock();
901
+
902
+ unmarkDeletedInternal(internal_id_replaced);
903
+ updatePoint(data_point, internal_id_replaced, 1.0);
781
904
  }
905
+ }
782
906
 
783
- /**
784
- * Uses the first 8 bits of the memory for the linked list to store the mark,
785
- * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
786
- * @param internalId
787
- */
788
- void markDeletedInternal(tableint internalId) {
789
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
790
- *ll_cur |= DELETE_MARK;
791
- }
792
907
 
793
- /**
794
- * Remove the deleted mark of the node.
795
- * @param internalId
796
- */
797
- void unmarkDeletedInternal(tableint internalId) {
798
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
799
- *ll_cur &= ~DELETE_MARK;
800
- }
908
+ void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
909
+ // update the feature vector associated with existing point with new vector
910
+ memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
801
911
 
802
- /**
803
- * Checks the first 8 bits of the memory to see if the element is marked deleted.
804
- * @param internalId
805
- * @return
806
- */
807
- bool isMarkedDeleted(tableint internalId) const {
808
- unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
809
- return *ll_cur & DELETE_MARK;
810
- }
912
+ int maxLevelCopy = maxlevel_;
913
+ tableint entryPointCopy = enterpoint_node_;
914
+ // If point to be updated is entry point and graph just contains single element then just return.
915
+ if (entryPointCopy == internalId && cur_element_count == 1)
916
+ return;
811
917
 
812
- unsigned short int getListCount(linklistsizeint * ptr) const {
813
- return *((unsigned short int *)ptr);
814
- }
918
+ int elemLevel = element_levels_[internalId];
919
+ std::uniform_real_distribution<float> distribution(0.0, 1.0);
920
+ for (int layer = 0; layer <= elemLevel; layer++) {
921
+ std::unordered_set<tableint> sCand;
922
+ std::unordered_set<tableint> sNeigh;
923
+ std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
924
+ if (listOneHop.size() == 0)
925
+ continue;
815
926
 
816
- void setListCount(linklistsizeint * ptr, unsigned short int size) const {
817
- *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
818
- }
927
+ sCand.insert(internalId);
819
928
 
820
- void addPoint(const void *data_point, labeltype label) {
821
- addPoint(data_point, label,-1);
822
- }
929
+ for (auto&& elOneHop : listOneHop) {
930
+ sCand.insert(elOneHop);
823
931
 
824
- void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
825
- // update the feature vector associated with existing point with new vector
826
- memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
827
-
828
- int maxLevelCopy = maxlevel_;
829
- tableint entryPointCopy = enterpoint_node_;
830
- // If point to be updated is entry point and graph just contains single element then just return.
831
- if (entryPointCopy == internalId && cur_element_count == 1)
832
- return;
833
-
834
- int elemLevel = element_levels_[internalId];
835
- std::uniform_real_distribution<float> distribution(0.0, 1.0);
836
- for (int layer = 0; layer <= elemLevel; layer++) {
837
- std::unordered_set<tableint> sCand;
838
- std::unordered_set<tableint> sNeigh;
839
- std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
840
- if (listOneHop.size() == 0)
932
+ if (distribution(update_probability_generator_) > updateNeighborProbability)
841
933
  continue;
842
934
 
843
- sCand.insert(internalId);
844
-
845
- for (auto&& elOneHop : listOneHop) {
846
- sCand.insert(elOneHop);
935
+ sNeigh.insert(elOneHop);
847
936
 
848
- if (distribution(update_probability_generator_) > updateNeighborProbability)
849
- continue;
850
-
851
- sNeigh.insert(elOneHop);
852
-
853
- std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
854
- for (auto&& elTwoHop : listTwoHop) {
855
- sCand.insert(elTwoHop);
856
- }
937
+ std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
938
+ for (auto&& elTwoHop : listTwoHop) {
939
+ sCand.insert(elTwoHop);
857
940
  }
941
+ }
858
942
 
859
- for (auto&& neigh : sNeigh) {
860
- // if (neigh == internalId)
861
- // continue;
943
+ for (auto&& neigh : sNeigh) {
944
+ // if (neigh == internalId)
945
+ // continue;
862
946
 
863
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
864
- size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
865
- size_t elementsToKeep = std::min(ef_construction_, size);
866
- for (auto&& cand : sCand) {
867
- if (cand == neigh)
868
- continue;
869
-
870
- dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
871
- if (candidates.size() < elementsToKeep) {
947
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
948
+ size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
949
+ size_t elementsToKeep = std::min(ef_construction_, size);
950
+ for (auto&& cand : sCand) {
951
+ if (cand == neigh)
952
+ continue;
953
+
954
+ dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
955
+ if (candidates.size() < elementsToKeep) {
956
+ candidates.emplace(distance, cand);
957
+ } else {
958
+ if (distance < candidates.top().first) {
959
+ candidates.pop();
872
960
  candidates.emplace(distance, cand);
873
- } else {
874
- if (distance < candidates.top().first) {
875
- candidates.pop();
876
- candidates.emplace(distance, cand);
877
- }
878
961
  }
879
962
  }
963
+ }
880
964
 
881
- // Retrieve neighbours using heuristic and set connections.
882
- getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
883
-
884
- {
885
- std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
886
- linklistsizeint *ll_cur;
887
- ll_cur = get_linklist_at_level(neigh, layer);
888
- size_t candSize = candidates.size();
889
- setListCount(ll_cur, candSize);
890
- tableint *data = (tableint *) (ll_cur + 1);
891
- for (size_t idx = 0; idx < candSize; idx++) {
892
- data[idx] = candidates.top().second;
893
- candidates.pop();
894
- }
965
+ // Retrieve neighbours using heuristic and set connections.
966
+ getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
967
+
968
+ {
969
+ std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
970
+ linklistsizeint *ll_cur;
971
+ ll_cur = get_linklist_at_level(neigh, layer);
972
+ size_t candSize = candidates.size();
973
+ setListCount(ll_cur, candSize);
974
+ tableint *data = (tableint *) (ll_cur + 1);
975
+ for (size_t idx = 0; idx < candSize; idx++) {
976
+ data[idx] = candidates.top().second;
977
+ candidates.pop();
895
978
  }
896
979
  }
897
980
  }
981
+ }
898
982
 
899
- repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
900
- };
983
+ repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
984
+ }
901
985
 
902
- void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
903
- tableint currObj = entryPointInternalId;
904
- if (dataPointLevel < maxLevel) {
905
- dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
906
- for (int level = maxLevel; level > dataPointLevel; level--) {
907
- bool changed = true;
908
- while (changed) {
909
- changed = false;
910
- unsigned int *data;
911
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
912
- data = get_linklist_at_level(currObj,level);
913
- int size = getListCount(data);
914
- tableint *datal = (tableint *) (data + 1);
986
+
987
+ void repairConnectionsForUpdate(
988
+ const void *dataPoint,
989
+ tableint entryPointInternalId,
990
+ tableint dataPointInternalId,
991
+ int dataPointLevel,
992
+ int maxLevel) {
993
+ tableint currObj = entryPointInternalId;
994
+ if (dataPointLevel < maxLevel) {
995
+ dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
996
+ for (int level = maxLevel; level > dataPointLevel; level--) {
997
+ bool changed = true;
998
+ while (changed) {
999
+ changed = false;
1000
+ unsigned int *data;
1001
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1002
+ data = get_linklist_at_level(currObj, level);
1003
+ int size = getListCount(data);
1004
+ tableint *datal = (tableint *) (data + 1);
915
1005
  #ifdef USE_SSE
916
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
1006
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
917
1007
  #endif
918
- for (int i = 0; i < size; i++) {
1008
+ for (int i = 0; i < size; i++) {
919
1009
  #ifdef USE_SSE
920
- _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
1010
+ _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
921
1011
  #endif
922
- tableint cand = datal[i];
923
- dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
924
- if (d < curdist) {
925
- curdist = d;
926
- currObj = cand;
927
- changed = true;
928
- }
1012
+ tableint cand = datal[i];
1013
+ dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
1014
+ if (d < curdist) {
1015
+ curdist = d;
1016
+ currObj = cand;
1017
+ changed = true;
929
1018
  }
930
1019
  }
931
1020
  }
932
1021
  }
1022
+ }
933
1023
 
934
- if (dataPointLevel > maxLevel)
935
- throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
936
-
937
- for (int level = dataPointLevel; level >= 0; level--) {
938
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
939
- currObj, dataPoint, level);
1024
+ if (dataPointLevel > maxLevel)
1025
+ throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
940
1026
 
941
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
942
- while (topCandidates.size() > 0) {
943
- if (topCandidates.top().second != dataPointInternalId)
944
- filteredTopCandidates.push(topCandidates.top());
1027
+ for (int level = dataPointLevel; level >= 0; level--) {
1028
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
1029
+ currObj, dataPoint, level);
945
1030
 
946
- topCandidates.pop();
947
- }
1031
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
1032
+ while (topCandidates.size() > 0) {
1033
+ if (topCandidates.top().second != dataPointInternalId)
1034
+ filteredTopCandidates.push(topCandidates.top());
948
1035
 
949
- // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
950
- // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
951
- if (filteredTopCandidates.size() > 0) {
952
- bool epDeleted = isMarkedDeleted(entryPointInternalId);
953
- if (epDeleted) {
954
- filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
955
- if (filteredTopCandidates.size() > ef_construction_)
956
- filteredTopCandidates.pop();
957
- }
1036
+ topCandidates.pop();
1037
+ }
958
1038
 
959
- currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
1039
+ // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
1040
+ // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
1041
+ if (filteredTopCandidates.size() > 0) {
1042
+ bool epDeleted = isMarkedDeleted(entryPointInternalId);
1043
+ if (epDeleted) {
1044
+ filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
1045
+ if (filteredTopCandidates.size() > ef_construction_)
1046
+ filteredTopCandidates.pop();
960
1047
  }
1048
+
1049
+ currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
961
1050
  }
962
1051
  }
1052
+ }
1053
+
963
1054
 
964
- std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
965
- std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
966
- unsigned int *data = get_linklist_at_level(internalId, level);
967
- int size = getListCount(data);
968
- std::vector<tableint> result(size);
969
- tableint *ll = (tableint *) (data + 1);
970
- memcpy(result.data(), ll,size * sizeof(tableint));
971
- return result;
972
- };
973
-
974
- tableint addPoint(const void *data_point, labeltype label, int level) {
975
-
976
- tableint cur_c = 0;
977
- {
978
- // Checking if the element with the same label already exists
979
- // if so, updating it *instead* of creating a new element.
980
- std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
981
- auto search = label_lookup_.find(label);
982
- if (search != label_lookup_.end()) {
983
- tableint existingInternalId = search->second;
984
- templock_curr.unlock();
985
-
986
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
1055
+ std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
1056
+ std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
1057
+ unsigned int *data = get_linklist_at_level(internalId, level);
1058
+ int size = getListCount(data);
1059
+ std::vector<tableint> result(size);
1060
+ tableint *ll = (tableint *) (data + 1);
1061
+ memcpy(result.data(), ll, size * sizeof(tableint));
1062
+ return result;
1063
+ }
987
1064
 
1065
+
1066
+ tableint addPoint(const void *data_point, labeltype label, int level) {
1067
+ tableint cur_c = 0;
1068
+ {
1069
+ // Checking if the element with the same label already exists
1070
+ // if so, updating it *instead* of creating a new element.
1071
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
1072
+ auto search = label_lookup_.find(label);
1073
+ if (search != label_lookup_.end()) {
1074
+ tableint existingInternalId = search->second;
1075
+ if (allow_replace_deleted_) {
988
1076
  if (isMarkedDeleted(existingInternalId)) {
989
- unmarkDeletedInternal(existingInternalId);
1077
+ throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
990
1078
  }
991
- updatePoint(data_point, existingInternalId, 1.0);
992
-
993
- return existingInternalId;
994
1079
  }
1080
+ lock_table.unlock();
995
1081
 
996
- if (cur_element_count >= max_elements_) {
997
- throw std::runtime_error("The number of elements exceeds the specified limit");
998
- };
1082
+ if (isMarkedDeleted(existingInternalId)) {
1083
+ unmarkDeletedInternal(existingInternalId);
1084
+ }
1085
+ updatePoint(data_point, existingInternalId, 1.0);
999
1086
 
1000
- cur_c = cur_element_count;
1001
- cur_element_count++;
1002
- label_lookup_[label] = cur_c;
1087
+ return existingInternalId;
1003
1088
  }
1004
1089
 
1005
- // Take update lock to prevent race conditions on an element with insertion/update at the same time.
1006
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
1007
- std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1008
- int curlevel = getRandomLevel(mult_);
1009
- if (level > 0)
1010
- curlevel = level;
1090
+ if (cur_element_count >= max_elements_) {
1091
+ throw std::runtime_error("The number of elements exceeds the specified limit");
1092
+ }
1011
1093
 
1012
- element_levels_[cur_c] = curlevel;
1094
+ cur_c = cur_element_count;
1095
+ cur_element_count++;
1096
+ label_lookup_[label] = cur_c;
1097
+ }
1013
1098
 
1099
+ std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1100
+ int curlevel = getRandomLevel(mult_);
1101
+ if (level > 0)
1102
+ curlevel = level;
1014
1103
 
1015
- std::unique_lock <std::mutex> templock(global);
1016
- int maxlevelcopy = maxlevel_;
1017
- if (curlevel <= maxlevelcopy)
1018
- templock.unlock();
1019
- tableint currObj = enterpoint_node_;
1020
- tableint enterpoint_copy = enterpoint_node_;
1104
+ element_levels_[cur_c] = curlevel;
1021
1105
 
1106
+ std::unique_lock <std::mutex> templock(global);
1107
+ int maxlevelcopy = maxlevel_;
1108
+ if (curlevel <= maxlevelcopy)
1109
+ templock.unlock();
1110
+ tableint currObj = enterpoint_node_;
1111
+ tableint enterpoint_copy = enterpoint_node_;
1022
1112
 
1023
- memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1113
+ memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1024
1114
 
1025
- // Initialisation of the data and label
1026
- memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1027
- memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1115
+ // Initialisation of the data and label
1116
+ memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1117
+ memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1028
1118
 
1119
+ if (curlevel) {
1120
+ linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1121
+ if (linkLists_[cur_c] == nullptr)
1122
+ throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1123
+ memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1124
+ }
1029
1125
 
1030
- if (curlevel) {
1031
- linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1032
- if (linkLists_[cur_c] == nullptr)
1033
- throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1034
- memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1035
- }
1126
+ if ((signed)currObj != -1) {
1127
+ if (curlevel < maxlevelcopy) {
1128
+ dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1129
+ for (int level = maxlevelcopy; level > curlevel; level--) {
1130
+ bool changed = true;
1131
+ while (changed) {
1132
+ changed = false;
1133
+ unsigned int *data;
1134
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1135
+ data = get_linklist(currObj, level);
1136
+ int size = getListCount(data);
1036
1137
 
1037
- if ((signed)currObj != -1) {
1038
-
1039
- if (curlevel < maxlevelcopy) {
1040
-
1041
- dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1042
- for (int level = maxlevelcopy; level > curlevel; level--) {
1043
-
1044
-
1045
- bool changed = true;
1046
- while (changed) {
1047
- changed = false;
1048
- unsigned int *data;
1049
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1050
- data = get_linklist(currObj,level);
1051
- int size = getListCount(data);
1052
-
1053
- tableint *datal = (tableint *) (data + 1);
1054
- for (int i = 0; i < size; i++) {
1055
- tableint cand = datal[i];
1056
- if (cand < 0 || cand > max_elements_)
1057
- throw std::runtime_error("cand error");
1058
- dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1059
- if (d < curdist) {
1060
- curdist = d;
1061
- currObj = cand;
1062
- changed = true;
1063
- }
1138
+ tableint *datal = (tableint *) (data + 1);
1139
+ for (int i = 0; i < size; i++) {
1140
+ tableint cand = datal[i];
1141
+ if (cand < 0 || cand > max_elements_)
1142
+ throw std::runtime_error("cand error");
1143
+ dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1144
+ if (d < curdist) {
1145
+ curdist = d;
1146
+ currObj = cand;
1147
+ changed = true;
1064
1148
  }
1065
1149
  }
1066
1150
  }
1067
1151
  }
1152
+ }
1068
1153
 
1069
- bool epDeleted = isMarkedDeleted(enterpoint_copy);
1070
- for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1071
- if (level > maxlevelcopy || level < 0) // possible?
1072
- throw std::runtime_error("Level error");
1073
-
1074
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1075
- currObj, data_point, level);
1076
- if (epDeleted) {
1077
- top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1078
- if (top_candidates.size() > ef_construction_)
1079
- top_candidates.pop();
1080
- }
1081
- currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1154
+ bool epDeleted = isMarkedDeleted(enterpoint_copy);
1155
+ for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1156
+ if (level > maxlevelcopy || level < 0) // possible?
1157
+ throw std::runtime_error("Level error");
1158
+
1159
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1160
+ currObj, data_point, level);
1161
+ if (epDeleted) {
1162
+ top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1163
+ if (top_candidates.size() > ef_construction_)
1164
+ top_candidates.pop();
1082
1165
  }
1083
-
1084
-
1085
- } else {
1086
- // Do nothing for the first element
1087
- enterpoint_node_ = 0;
1088
- maxlevel_ = curlevel;
1089
-
1166
+ currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1090
1167
  }
1168
+ } else {
1169
+ // Do nothing for the first element
1170
+ enterpoint_node_ = 0;
1171
+ maxlevel_ = curlevel;
1172
+ }
1091
1173
 
1092
- //Releasing lock for the maximum level
1093
- if (curlevel > maxlevelcopy) {
1094
- enterpoint_node_ = cur_c;
1095
- maxlevel_ = curlevel;
1096
- }
1097
- return cur_c;
1098
- };
1174
+ // Releasing lock for the maximum level
1175
+ if (curlevel > maxlevelcopy) {
1176
+ enterpoint_node_ = cur_c;
1177
+ maxlevel_ = curlevel;
1178
+ }
1179
+ return cur_c;
1180
+ }
1099
1181
 
1100
- std::priority_queue<std::pair<dist_t, labeltype >>
1101
- searchKnn(const void *query_data, size_t k) const {
1102
- std::priority_queue<std::pair<dist_t, labeltype >> result;
1103
- if (cur_element_count == 0) return result;
1104
1182
 
1105
- tableint currObj = enterpoint_node_;
1106
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1183
+ std::priority_queue<std::pair<dist_t, labeltype >>
1184
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
1185
+ std::priority_queue<std::pair<dist_t, labeltype >> result;
1186
+ if (cur_element_count == 0) return result;
1107
1187
 
1108
- for (int level = maxlevel_; level > 0; level--) {
1109
- bool changed = true;
1110
- while (changed) {
1111
- changed = false;
1112
- unsigned int *data;
1188
+ tableint currObj = enterpoint_node_;
1189
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1113
1190
 
1114
- data = (unsigned int *) get_linklist(currObj, level);
1115
- int size = getListCount(data);
1116
- metric_hops++;
1117
- metric_distance_computations+=size;
1191
+ for (int level = maxlevel_; level > 0; level--) {
1192
+ bool changed = true;
1193
+ while (changed) {
1194
+ changed = false;
1195
+ unsigned int *data;
1118
1196
 
1119
- tableint *datal = (tableint *) (data + 1);
1120
- for (int i = 0; i < size; i++) {
1121
- tableint cand = datal[i];
1122
- if (cand < 0 || cand > max_elements_)
1123
- throw std::runtime_error("cand error");
1124
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1197
+ data = (unsigned int *) get_linklist(currObj, level);
1198
+ int size = getListCount(data);
1199
+ metric_hops++;
1200
+ metric_distance_computations+=size;
1125
1201
 
1126
- if (d < curdist) {
1127
- curdist = d;
1128
- currObj = cand;
1129
- changed = true;
1130
- }
1202
+ tableint *datal = (tableint *) (data + 1);
1203
+ for (int i = 0; i < size; i++) {
1204
+ tableint cand = datal[i];
1205
+ if (cand < 0 || cand > max_elements_)
1206
+ throw std::runtime_error("cand error");
1207
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1208
+
1209
+ if (d < curdist) {
1210
+ curdist = d;
1211
+ currObj = cand;
1212
+ changed = true;
1131
1213
  }
1132
1214
  }
1133
1215
  }
1216
+ }
1134
1217
 
1135
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1136
- if (has_deletions_) {
1137
- top_candidates=searchBaseLayerST<true,true>(
1138
- currObj, query_data, std::max(ef_, k));
1139
- }
1140
- else{
1141
- top_candidates=searchBaseLayerST<false,true>(
1142
- currObj, query_data, std::max(ef_, k));
1143
- }
1218
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1219
+ if (num_deleted_) {
1220
+ top_candidates = searchBaseLayerST<true, true>(
1221
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1222
+ } else {
1223
+ top_candidates = searchBaseLayerST<false, true>(
1224
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1225
+ }
1144
1226
 
1145
- while (top_candidates.size() > k) {
1146
- top_candidates.pop();
1147
- }
1148
- while (top_candidates.size() > 0) {
1149
- std::pair<dist_t, tableint> rez = top_candidates.top();
1150
- result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1151
- top_candidates.pop();
1152
- }
1153
- return result;
1154
- };
1155
-
1156
- void checkIntegrity(){
1157
- int connections_checked=0;
1158
- std::vector <int > inbound_connections_num(cur_element_count,0);
1159
- for(int i = 0;i < cur_element_count; i++){
1160
- for(int l = 0;l <= element_levels_[i]; l++){
1161
- linklistsizeint *ll_cur = get_linklist_at_level(i,l);
1162
- int size = getListCount(ll_cur);
1163
- tableint *data = (tableint *) (ll_cur + 1);
1164
- std::unordered_set<tableint> s;
1165
- for (int j=0; j<size; j++){
1166
- assert(data[j] > 0);
1167
- assert(data[j] < cur_element_count);
1168
- assert (data[j] != i);
1169
- inbound_connections_num[data[j]]++;
1170
- s.insert(data[j]);
1171
- connections_checked++;
1227
+ while (top_candidates.size() > k) {
1228
+ top_candidates.pop();
1229
+ }
1230
+ while (top_candidates.size() > 0) {
1231
+ std::pair<dist_t, tableint> rez = top_candidates.top();
1232
+ result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1233
+ top_candidates.pop();
1234
+ }
1235
+ return result;
1236
+ }
1172
1237
 
1173
- }
1174
- assert(s.size() == size);
1238
+
1239
+ void checkIntegrity() {
1240
+ int connections_checked = 0;
1241
+ std::vector <int > inbound_connections_num(cur_element_count, 0);
1242
+ for (int i = 0; i < cur_element_count; i++) {
1243
+ for (int l = 0; l <= element_levels_[i]; l++) {
1244
+ linklistsizeint *ll_cur = get_linklist_at_level(i, l);
1245
+ int size = getListCount(ll_cur);
1246
+ tableint *data = (tableint *) (ll_cur + 1);
1247
+ std::unordered_set<tableint> s;
1248
+ for (int j = 0; j < size; j++) {
1249
+ assert(data[j] > 0);
1250
+ assert(data[j] < cur_element_count);
1251
+ assert(data[j] != i);
1252
+ inbound_connections_num[data[j]]++;
1253
+ s.insert(data[j]);
1254
+ connections_checked++;
1175
1255
  }
1256
+ assert(s.size() == size);
1176
1257
  }
1177
- if(cur_element_count > 1){
1178
- int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1179
- for(int i=0; i < cur_element_count; i++){
1180
- assert(inbound_connections_num[i] > 0);
1181
- min1=std::min(inbound_connections_num[i],min1);
1182
- max1=std::max(inbound_connections_num[i],max1);
1183
- }
1184
- std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1258
+ }
1259
+ if (cur_element_count > 1) {
1260
+ int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
1261
+ for (int i=0; i < cur_element_count; i++) {
1262
+ assert(inbound_connections_num[i] > 0);
1263
+ min1 = std::min(inbound_connections_num[i], min1);
1264
+ max1 = std::max(inbound_connections_num[i], max1);
1185
1265
  }
1186
- std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1187
-
1266
+ std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1188
1267
  }
1189
-
1190
- };
1191
-
1192
- }
1268
+ std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1269
+ }
1270
+ };
1271
+ } // namespace hnswlib