umappp 0.1.6 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,1183 +10,1262 @@
10
10
  #include <list>
11
11
 
12
12
  namespace hnswlib {
13
- typedef unsigned int tableint;
14
- typedef unsigned int linklistsizeint;
13
+ typedef unsigned int tableint;
14
+ typedef unsigned int linklistsizeint;
15
+
16
+ template<typename dist_t>
17
+ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
+ public:
19
+ static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
20
+ static const unsigned char DELETE_MARK = 0x01;
21
+
22
+ size_t max_elements_{0};
23
+ mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
24
+ size_t size_data_per_element_{0};
25
+ size_t size_links_per_element_{0};
26
+ mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
27
+ size_t M_{0};
28
+ size_t maxM_{0};
29
+ size_t maxM0_{0};
30
+ size_t ef_construction_{0};
31
+ size_t ef_{ 0 };
32
+
33
+ double mult_{0.0}, revSize_{0.0};
34
+ int maxlevel_{0};
35
+
36
+ VisitedListPool *visited_list_pool_{nullptr};
37
+
38
+ // Locks operations with element by label value
39
+ mutable std::vector<std::mutex> label_op_locks_;
40
+
41
+ std::mutex global;
42
+ std::vector<std::mutex> link_list_locks_;
43
+
44
+ tableint enterpoint_node_{0};
45
+
46
+ size_t size_links_level0_{0};
47
+ size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 };
48
+
49
+ char *data_level0_memory_{nullptr};
50
+ char **linkLists_{nullptr};
51
+ std::vector<int> element_levels_; // keeps level of each element
52
+
53
+ size_t data_size_{0};
54
+
55
+ DISTFUNC<dist_t> fstdistfunc_;
56
+ void *dist_func_param_{nullptr};
57
+
58
+ mutable std::mutex label_lookup_lock; // lock for label_lookup_
59
+ std::unordered_map<labeltype, tableint> label_lookup_;
60
+
61
+ std::default_random_engine level_generator_;
62
+ std::default_random_engine update_probability_generator_;
63
+
64
+ mutable std::atomic<long> metric_distance_computations{0};
65
+ mutable std::atomic<long> metric_hops{0};
66
+
67
+ bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
68
+
69
+ std::mutex deleted_elements_lock; // lock for deleted_elements
70
+ std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
71
+
72
+
73
+ HierarchicalNSW(SpaceInterface<dist_t> *s) {
74
+ }
75
+
76
+
77
+ HierarchicalNSW(
78
+ SpaceInterface<dist_t> *s,
79
+ const std::string &location,
80
+ bool nmslib = false,
81
+ size_t max_elements = 0,
82
+ bool allow_replace_deleted = false)
83
+ : allow_replace_deleted_(allow_replace_deleted) {
84
+ loadIndex(location, s, max_elements);
85
+ }
86
+
87
+
88
+ HierarchicalNSW(
89
+ SpaceInterface<dist_t> *s,
90
+ size_t max_elements,
91
+ size_t M = 16,
92
+ size_t ef_construction = 200,
93
+ size_t random_seed = 100,
94
+ bool allow_replace_deleted = false)
95
+ : label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
96
+ link_list_locks_(max_elements),
97
+ element_levels_(max_elements),
98
+ allow_replace_deleted_(allow_replace_deleted) {
99
+ max_elements_ = max_elements;
100
+ num_deleted_ = 0;
101
+ data_size_ = s->get_data_size();
102
+ fstdistfunc_ = s->get_dist_func();
103
+ dist_func_param_ = s->get_dist_func_param();
104
+ M_ = M;
105
+ maxM_ = M_;
106
+ maxM0_ = M_ * 2;
107
+ ef_construction_ = std::max(ef_construction, M_);
108
+ ef_ = 10;
109
+
110
+ level_generator_.seed(random_seed);
111
+ update_probability_generator_.seed(random_seed + 1);
112
+
113
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
114
+ size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
115
+ offsetData_ = size_links_level0_;
116
+ label_offset_ = size_links_level0_ + data_size_;
117
+ offsetLevel0_ = 0;
118
+
119
+ data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
120
+ if (data_level0_memory_ == nullptr)
121
+ throw std::runtime_error("Not enough memory");
122
+
123
+ cur_element_count = 0;
124
+
125
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
15
126
 
16
- template<typename dist_t>
17
- class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
- public:
19
- static const tableint max_update_element_locks = 65536;
20
- HierarchicalNSW(SpaceInterface<dist_t> *s) {
127
+ // initializations for special treatment of the first node
128
+ enterpoint_node_ = -1;
129
+ maxlevel_ = -1;
21
130
 
22
- }
23
-
24
- HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
25
- loadIndex(location, s, max_elements);
26
- }
27
-
28
- HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
29
- link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
30
- max_elements_ = max_elements;
31
-
32
- has_deletions_=false;
33
- data_size_ = s->get_data_size();
34
- fstdistfunc_ = s->get_dist_func();
35
- dist_func_param_ = s->get_dist_func_param();
36
- M_ = M;
37
- maxM_ = M_;
38
- maxM0_ = M_ * 2;
39
- ef_construction_ = std::max(ef_construction,M_);
40
- ef_ = 10;
41
-
42
- level_generator_.seed(random_seed);
43
- update_probability_generator_.seed(random_seed + 1);
44
-
45
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
46
- size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
47
- offsetData_ = size_links_level0_;
48
- label_offset_ = size_links_level0_ + data_size_;
49
- offsetLevel0_ = 0;
50
-
51
- data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
52
- if (data_level0_memory_ == nullptr)
53
- throw std::runtime_error("Not enough memory");
54
-
55
- cur_element_count = 0;
56
-
57
- visited_list_pool_ = new VisitedListPool(1, max_elements);
131
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
132
+ if (linkLists_ == nullptr)
133
+ throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
134
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
135
+ mult_ = 1 / log(1.0 * M_);
136
+ revSize_ = 1.0 / mult_;
137
+ }
58
138
 
59
139
 
60
-
61
- //initializations for special treatment of the first node
62
- enterpoint_node_ = -1;
63
- maxlevel_ = -1;
64
-
65
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
66
- if (linkLists_ == nullptr)
67
- throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
68
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
69
- mult_ = 1 / log(1.0 * M_);
70
- revSize_ = 1.0 / mult_;
140
+ ~HierarchicalNSW() {
141
+ free(data_level0_memory_);
142
+ for (tableint i = 0; i < cur_element_count; i++) {
143
+ if (element_levels_[i] > 0)
144
+ free(linkLists_[i]);
71
145
  }
146
+ free(linkLists_);
147
+ delete visited_list_pool_;
148
+ }
72
149
 
73
- struct CompareByFirst {
74
- constexpr bool operator()(std::pair<dist_t, tableint> const &a,
75
- std::pair<dist_t, tableint> const &b) const noexcept {
76
- return a.first < b.first;
77
- }
78
- };
79
150
 
80
- ~HierarchicalNSW() {
81
-
82
- free(data_level0_memory_);
83
- for (tableint i = 0; i < cur_element_count; i++) {
84
- if (element_levels_[i] > 0)
85
- free(linkLists_[i]);
86
- }
87
- free(linkLists_);
88
- delete visited_list_pool_;
151
+ struct CompareByFirst {
152
+ constexpr bool operator()(std::pair<dist_t, tableint> const& a,
153
+ std::pair<dist_t, tableint> const& b) const noexcept {
154
+ return a.first < b.first;
89
155
  }
156
+ };
90
157
 
91
- size_t max_elements_;
92
- size_t cur_element_count;
93
- size_t size_data_per_element_;
94
- size_t size_links_per_element_;
95
158
 
96
- size_t M_;
97
- size_t maxM_;
98
- size_t maxM0_;
99
- size_t ef_construction_;
159
+ void setEf(size_t ef) {
160
+ ef_ = ef;
161
+ }
100
162
 
101
- double mult_, revSize_;
102
- int maxlevel_;
103
163
 
164
+ inline std::mutex& getLabelOpMutex(labeltype label) const {
165
+ // calculate hash
166
+ size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
167
+ return label_op_locks_[lock_id];
168
+ }
104
169
 
105
- VisitedListPool *visited_list_pool_;
106
- std::mutex cur_element_count_guard_;
107
170
 
108
- std::vector<std::mutex> link_list_locks_;
171
+ inline labeltype getExternalLabel(tableint internal_id) const {
172
+ labeltype return_label;
173
+ memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
174
+ return return_label;
175
+ }
109
176
 
110
- // Locks to prevent race condition during update/insert of an element at same time.
111
- // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
112
- std::vector<std::mutex> link_list_update_locks_;
113
- tableint enterpoint_node_;
114
177
 
178
+ inline void setExternalLabel(tableint internal_id, labeltype label) const {
179
+ memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
180
+ }
115
181
 
116
- size_t size_links_level0_;
117
- size_t offsetData_, offsetLevel0_;
118
182
 
183
+ inline labeltype *getExternalLabeLp(tableint internal_id) const {
184
+ return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
185
+ }
119
186
 
120
- char *data_level0_memory_;
121
- char **linkLists_;
122
- std::vector<int> element_levels_;
123
187
 
124
- size_t data_size_;
188
+ inline char *getDataByInternalId(tableint internal_id) const {
189
+ return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
190
+ }
125
191
 
126
- bool has_deletions_;
127
192
 
193
+ int getRandomLevel(double reverse_size) {
194
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
195
+ double r = -log(distribution(level_generator_)) * reverse_size;
196
+ return (int) r;
197
+ }
128
198
 
129
- size_t label_offset_;
130
- DISTFUNC<dist_t> fstdistfunc_;
131
- void *dist_func_param_;
132
- std::unordered_map<labeltype, tableint> label_lookup_;
199
+ size_t getMaxElements() {
200
+ return max_elements_;
201
+ }
133
202
 
134
- std::default_random_engine level_generator_;
135
- std::default_random_engine update_probability_generator_;
203
+ size_t getCurrentElementCount() {
204
+ return cur_element_count;
205
+ }
136
206
 
137
- inline labeltype getExternalLabel(tableint internal_id) const {
138
- labeltype return_label;
139
- memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
140
- return return_label;
141
- }
207
+ size_t getDeletedCount() {
208
+ return num_deleted_;
209
+ }
142
210
 
143
- inline void setExternalLabel(tableint internal_id, labeltype label) const {
144
- memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
145
- }
211
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
212
+ searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
213
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
214
+ vl_type *visited_array = vl->mass;
215
+ vl_type visited_array_tag = vl->curV;
146
216
 
147
- inline labeltype *getExternalLabeLp(tableint internal_id) const {
148
- return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
149
- }
217
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
218
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
150
219
 
151
- inline char *getDataByInternalId(tableint internal_id) const {
152
- return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
220
+ dist_t lowerBound;
221
+ if (!isMarkedDeleted(ep_id)) {
222
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
223
+ top_candidates.emplace(dist, ep_id);
224
+ lowerBound = dist;
225
+ candidateSet.emplace(-dist, ep_id);
226
+ } else {
227
+ lowerBound = std::numeric_limits<dist_t>::max();
228
+ candidateSet.emplace(-lowerBound, ep_id);
153
229
  }
230
+ visited_array[ep_id] = visited_array_tag;
154
231
 
155
- int getRandomLevel(double reverse_size) {
156
- std::uniform_real_distribution<double> distribution(0.0, 1.0);
157
- double r = -log(distribution(level_generator_)) * reverse_size;
158
- return (int) r;
159
- }
160
-
161
-
162
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
163
- searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
164
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
165
- vl_type *visited_array = vl->mass;
166
- vl_type visited_array_tag = vl->curV;
167
-
168
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
169
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
170
-
171
- dist_t lowerBound;
172
- if (!isMarkedDeleted(ep_id)) {
173
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
174
- top_candidates.emplace(dist, ep_id);
175
- lowerBound = dist;
176
- candidateSet.emplace(-dist, ep_id);
177
- } else {
178
- lowerBound = std::numeric_limits<dist_t>::max();
179
- candidateSet.emplace(-lowerBound, ep_id);
232
+ while (!candidateSet.empty()) {
233
+ std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
234
+ if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
235
+ break;
180
236
  }
181
- visited_array[ep_id] = visited_array_tag;
182
-
183
- while (!candidateSet.empty()) {
184
- std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
185
- if ((-curr_el_pair.first) > lowerBound) {
186
- break;
187
- }
188
- candidateSet.pop();
237
+ candidateSet.pop();
189
238
 
190
- tableint curNodeNum = curr_el_pair.second;
239
+ tableint curNodeNum = curr_el_pair.second;
191
240
 
192
- std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
241
+ std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
193
242
 
194
- int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
195
- if (layer == 0) {
196
- data = (int*)get_linklist0(curNodeNum);
197
- } else {
198
- data = (int*)get_linklist(curNodeNum, layer);
243
+ int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
244
+ if (layer == 0) {
245
+ data = (int*)get_linklist0(curNodeNum);
246
+ } else {
247
+ data = (int*)get_linklist(curNodeNum, layer);
199
248
  // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
200
- }
201
- size_t size = getListCount((linklistsizeint*)data);
202
- tableint *datal = (tableint *) (data + 1);
249
+ }
250
+ size_t size = getListCount((linklistsizeint*)data);
251
+ tableint *datal = (tableint *) (data + 1);
203
252
  #ifdef USE_SSE
204
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
205
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
206
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
207
- _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
253
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
254
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
255
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
256
+ _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
208
257
  #endif
209
258
 
210
- for (size_t j = 0; j < size; j++) {
211
- tableint candidate_id = *(datal + j);
259
+ for (size_t j = 0; j < size; j++) {
260
+ tableint candidate_id = *(datal + j);
212
261
  // if (candidate_id == 0) continue;
213
262
  #ifdef USE_SSE
214
- _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
215
- _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
263
+ _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
264
+ _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
216
265
  #endif
217
- if (visited_array[candidate_id] == visited_array_tag) continue;
218
- visited_array[candidate_id] = visited_array_tag;
219
- char *currObj1 = (getDataByInternalId(candidate_id));
266
+ if (visited_array[candidate_id] == visited_array_tag) continue;
267
+ visited_array[candidate_id] = visited_array_tag;
268
+ char *currObj1 = (getDataByInternalId(candidate_id));
220
269
 
221
- dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
222
- if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
223
- candidateSet.emplace(-dist1, candidate_id);
270
+ dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
271
+ if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
272
+ candidateSet.emplace(-dist1, candidate_id);
224
273
  #ifdef USE_SSE
225
- _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
274
+ _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
226
275
  #endif
227
276
 
228
- if (!isMarkedDeleted(candidate_id))
229
- top_candidates.emplace(dist1, candidate_id);
277
+ if (!isMarkedDeleted(candidate_id))
278
+ top_candidates.emplace(dist1, candidate_id);
230
279
 
231
- if (top_candidates.size() > ef_construction_)
232
- top_candidates.pop();
280
+ if (top_candidates.size() > ef_construction_)
281
+ top_candidates.pop();
233
282
 
234
- if (!top_candidates.empty())
235
- lowerBound = top_candidates.top().first;
236
- }
283
+ if (!top_candidates.empty())
284
+ lowerBound = top_candidates.top().first;
237
285
  }
238
286
  }
239
- visited_list_pool_->releaseVisitedList(vl);
240
-
241
- return top_candidates;
287
+ }
288
+ visited_list_pool_->releaseVisitedList(vl);
289
+
290
+ return top_candidates;
291
+ }
292
+
293
+
294
+ template <bool has_deletions, bool collect_metrics = false>
295
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
296
+ searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
297
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
298
+ vl_type *visited_array = vl->mass;
299
+ vl_type visited_array_tag = vl->curV;
300
+
301
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
302
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
303
+
304
+ dist_t lowerBound;
305
+ if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
306
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
307
+ lowerBound = dist;
308
+ top_candidates.emplace(dist, ep_id);
309
+ candidate_set.emplace(-dist, ep_id);
310
+ } else {
311
+ lowerBound = std::numeric_limits<dist_t>::max();
312
+ candidate_set.emplace(-lowerBound, ep_id);
242
313
  }
243
314
 
244
- mutable std::atomic<long> metric_distance_computations;
245
- mutable std::atomic<long> metric_hops;
246
-
247
- template <bool has_deletions, bool collect_metrics=false>
248
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
249
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
250
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
251
- vl_type *visited_array = vl->mass;
252
- vl_type visited_array_tag = vl->curV;
253
-
254
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
255
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
256
-
257
- dist_t lowerBound;
258
- if (!has_deletions || !isMarkedDeleted(ep_id)) {
259
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
260
- lowerBound = dist;
261
- top_candidates.emplace(dist, ep_id);
262
- candidate_set.emplace(-dist, ep_id);
263
- } else {
264
- lowerBound = std::numeric_limits<dist_t>::max();
265
- candidate_set.emplace(-lowerBound, ep_id);
266
- }
267
-
268
- visited_array[ep_id] = visited_array_tag;
269
-
270
- while (!candidate_set.empty()) {
315
+ visited_array[ep_id] = visited_array_tag;
271
316
 
272
- std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
317
+ while (!candidate_set.empty()) {
318
+ std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
273
319
 
274
- if ((-current_node_pair.first) > lowerBound) {
275
- break;
276
- }
277
- candidate_set.pop();
320
+ if ((-current_node_pair.first) > lowerBound &&
321
+ (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
322
+ break;
323
+ }
324
+ candidate_set.pop();
278
325
 
279
- tableint current_node_id = current_node_pair.second;
280
- int *data = (int *) get_linklist0(current_node_id);
281
- size_t size = getListCount((linklistsizeint*)data);
326
+ tableint current_node_id = current_node_pair.second;
327
+ int *data = (int *) get_linklist0(current_node_id);
328
+ size_t size = getListCount((linklistsizeint*)data);
282
329
  // bool cur_node_deleted = isMarkedDeleted(current_node_id);
283
- if(collect_metrics){
284
- metric_hops++;
285
- metric_distance_computations+=size;
286
- }
330
+ if (collect_metrics) {
331
+ metric_hops++;
332
+ metric_distance_computations+=size;
333
+ }
287
334
 
288
335
  #ifdef USE_SSE
289
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
290
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
291
- _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
292
- _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
336
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
337
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
338
+ _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
339
+ _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
293
340
  #endif
294
341
 
295
- for (size_t j = 1; j <= size; j++) {
296
- int candidate_id = *(data + j);
342
+ for (size_t j = 1; j <= size; j++) {
343
+ int candidate_id = *(data + j);
297
344
  // if (candidate_id == 0) continue;
298
345
  #ifdef USE_SSE
299
- _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
300
- _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
301
- _MM_HINT_T0);////////////
346
+ _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
347
+ _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
348
+ _MM_HINT_T0); ////////////
302
349
  #endif
303
- if (!(visited_array[candidate_id] == visited_array_tag)) {
304
-
305
- visited_array[candidate_id] = visited_array_tag;
350
+ if (!(visited_array[candidate_id] == visited_array_tag)) {
351
+ visited_array[candidate_id] = visited_array_tag;
306
352
 
307
- char *currObj1 = (getDataByInternalId(candidate_id));
308
- dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
353
+ char *currObj1 = (getDataByInternalId(candidate_id));
354
+ dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
309
355
 
310
- if (top_candidates.size() < ef || lowerBound > dist) {
311
- candidate_set.emplace(-dist, candidate_id);
356
+ if (top_candidates.size() < ef || lowerBound > dist) {
357
+ candidate_set.emplace(-dist, candidate_id);
312
358
  #ifdef USE_SSE
313
- _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
314
- offsetLevel0_,///////////
315
- _MM_HINT_T0);////////////////////////
359
+ _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
360
+ offsetLevel0_, ///////////
361
+ _MM_HINT_T0); ////////////////////////
316
362
  #endif
317
363
 
318
- if (!has_deletions || !isMarkedDeleted(candidate_id))
319
- top_candidates.emplace(dist, candidate_id);
364
+ if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
365
+ top_candidates.emplace(dist, candidate_id);
320
366
 
321
- if (top_candidates.size() > ef)
322
- top_candidates.pop();
367
+ if (top_candidates.size() > ef)
368
+ top_candidates.pop();
323
369
 
324
- if (!top_candidates.empty())
325
- lowerBound = top_candidates.top().first;
326
- }
370
+ if (!top_candidates.empty())
371
+ lowerBound = top_candidates.top().first;
327
372
  }
328
373
  }
329
374
  }
330
-
331
- visited_list_pool_->releaseVisitedList(vl);
332
- return top_candidates;
333
375
  }
334
376
 
335
- void getNeighborsByHeuristic2(
336
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
337
- const size_t M) {
338
- if (top_candidates.size() < M) {
339
- return;
340
- }
377
+ visited_list_pool_->releaseVisitedList(vl);
378
+ return top_candidates;
379
+ }
341
380
 
342
- std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
343
- std::vector<std::pair<dist_t, tableint>> return_list;
344
- while (top_candidates.size() > 0) {
345
- queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
346
- top_candidates.pop();
347
- }
348
381
 
349
- while (queue_closest.size()) {
350
- if (return_list.size() >= M)
382
+ void getNeighborsByHeuristic2(
383
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
384
+ const size_t M) {
385
+ if (top_candidates.size() < M) {
386
+ return;
387
+ }
388
+
389
+ std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
390
+ std::vector<std::pair<dist_t, tableint>> return_list;
391
+ while (top_candidates.size() > 0) {
392
+ queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
393
+ top_candidates.pop();
394
+ }
395
+
396
+ while (queue_closest.size()) {
397
+ if (return_list.size() >= M)
398
+ break;
399
+ std::pair<dist_t, tableint> curent_pair = queue_closest.top();
400
+ dist_t dist_to_query = -curent_pair.first;
401
+ queue_closest.pop();
402
+ bool good = true;
403
+
404
+ for (std::pair<dist_t, tableint> second_pair : return_list) {
405
+ dist_t curdist =
406
+ fstdistfunc_(getDataByInternalId(second_pair.second),
407
+ getDataByInternalId(curent_pair.second),
408
+ dist_func_param_);
409
+ if (curdist < dist_to_query) {
410
+ good = false;
351
411
  break;
352
- std::pair<dist_t, tableint> curent_pair = queue_closest.top();
353
- dist_t dist_to_query = -curent_pair.first;
354
- queue_closest.pop();
355
- bool good = true;
356
-
357
- for (std::pair<dist_t, tableint> second_pair : return_list) {
358
- dist_t curdist =
359
- fstdistfunc_(getDataByInternalId(second_pair.second),
360
- getDataByInternalId(curent_pair.second),
361
- dist_func_param_);;
362
- if (curdist < dist_to_query) {
363
- good = false;
364
- break;
365
- }
366
- }
367
- if (good) {
368
- return_list.push_back(curent_pair);
369
412
  }
370
413
  }
371
-
372
- for (std::pair<dist_t, tableint> curent_pair : return_list) {
373
- top_candidates.emplace(-curent_pair.first, curent_pair.second);
414
+ if (good) {
415
+ return_list.push_back(curent_pair);
374
416
  }
375
417
  }
376
418
 
419
+ for (std::pair<dist_t, tableint> curent_pair : return_list) {
420
+ top_candidates.emplace(-curent_pair.first, curent_pair.second);
421
+ }
422
+ }
377
423
 
378
- linklistsizeint *get_linklist0(tableint internal_id) const {
379
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
380
- };
381
424
 
382
- linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
383
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
384
- };
425
+ linklistsizeint *get_linklist0(tableint internal_id) const {
426
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
427
+ }
385
428
 
386
- linklistsizeint *get_linklist(tableint internal_id, int level) const {
387
- return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
388
- };
389
429
 
390
- linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
391
- return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
392
- };
430
+ linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
431
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
432
+ }
393
433
 
394
- tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
395
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
396
- int level, bool isUpdate) {
397
- size_t Mcurmax = level ? maxM_ : maxM0_;
398
- getNeighborsByHeuristic2(top_candidates, M_);
399
- if (top_candidates.size() > M_)
400
- throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
401
434
 
402
- std::vector<tableint> selectedNeighbors;
403
- selectedNeighbors.reserve(M_);
404
- while (top_candidates.size() > 0) {
405
- selectedNeighbors.push_back(top_candidates.top().second);
406
- top_candidates.pop();
407
- }
435
+ linklistsizeint *get_linklist(tableint internal_id, int level) const {
436
+ return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
437
+ }
408
438
 
409
- tableint next_closest_entry_point = selectedNeighbors.back();
410
439
 
411
- {
412
- linklistsizeint *ll_cur;
413
- if (level == 0)
414
- ll_cur = get_linklist0(cur_c);
415
- else
416
- ll_cur = get_linklist(cur_c, level);
440
+ linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
441
+ return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
442
+ }
417
443
 
418
- if (*ll_cur && !isUpdate) {
419
- throw std::runtime_error("The newly inserted element should have blank link list");
420
- }
421
- setListCount(ll_cur,selectedNeighbors.size());
422
- tableint *data = (tableint *) (ll_cur + 1);
423
- for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
424
- if (data[idx] && !isUpdate)
425
- throw std::runtime_error("Possible memory corruption");
426
- if (level > element_levels_[selectedNeighbors[idx]])
427
- throw std::runtime_error("Trying to make a link on a non-existent level");
428
444
 
429
- data[idx] = selectedNeighbors[idx];
445
+ tableint mutuallyConnectNewElement(
446
+ const void *data_point,
447
+ tableint cur_c,
448
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
449
+ int level,
450
+ bool isUpdate) {
451
+ size_t Mcurmax = level ? maxM_ : maxM0_;
452
+ getNeighborsByHeuristic2(top_candidates, M_);
453
+ if (top_candidates.size() > M_)
454
+ throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
430
455
 
431
- }
432
- }
456
+ std::vector<tableint> selectedNeighbors;
457
+ selectedNeighbors.reserve(M_);
458
+ while (top_candidates.size() > 0) {
459
+ selectedNeighbors.push_back(top_candidates.top().second);
460
+ top_candidates.pop();
461
+ }
462
+
463
+ tableint next_closest_entry_point = selectedNeighbors.back();
433
464
 
465
+ {
466
+ // lock only during the update
467
+ // because during the addition the lock for cur_c is already acquired
468
+ std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
469
+ if (isUpdate) {
470
+ lock.lock();
471
+ }
472
+ linklistsizeint *ll_cur;
473
+ if (level == 0)
474
+ ll_cur = get_linklist0(cur_c);
475
+ else
476
+ ll_cur = get_linklist(cur_c, level);
477
+
478
+ if (*ll_cur && !isUpdate) {
479
+ throw std::runtime_error("The newly inserted element should have blank link list");
480
+ }
481
+ setListCount(ll_cur, selectedNeighbors.size());
482
+ tableint *data = (tableint *) (ll_cur + 1);
434
483
  for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
484
+ if (data[idx] && !isUpdate)
485
+ throw std::runtime_error("Possible memory corruption");
486
+ if (level > element_levels_[selectedNeighbors[idx]])
487
+ throw std::runtime_error("Trying to make a link on a non-existent level");
435
488
 
436
- std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
489
+ data[idx] = selectedNeighbors[idx];
490
+ }
491
+ }
437
492
 
438
- linklistsizeint *ll_other;
439
- if (level == 0)
440
- ll_other = get_linklist0(selectedNeighbors[idx]);
441
- else
442
- ll_other = get_linklist(selectedNeighbors[idx], level);
493
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
494
+ std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
443
495
 
444
- size_t sz_link_list_other = getListCount(ll_other);
496
+ linklistsizeint *ll_other;
497
+ if (level == 0)
498
+ ll_other = get_linklist0(selectedNeighbors[idx]);
499
+ else
500
+ ll_other = get_linklist(selectedNeighbors[idx], level);
445
501
 
446
- if (sz_link_list_other > Mcurmax)
447
- throw std::runtime_error("Bad value of sz_link_list_other");
448
- if (selectedNeighbors[idx] == cur_c)
449
- throw std::runtime_error("Trying to connect an element to itself");
450
- if (level > element_levels_[selectedNeighbors[idx]])
451
- throw std::runtime_error("Trying to make a link on a non-existent level");
502
+ size_t sz_link_list_other = getListCount(ll_other);
452
503
 
453
- tableint *data = (tableint *) (ll_other + 1);
504
+ if (sz_link_list_other > Mcurmax)
505
+ throw std::runtime_error("Bad value of sz_link_list_other");
506
+ if (selectedNeighbors[idx] == cur_c)
507
+ throw std::runtime_error("Trying to connect an element to itself");
508
+ if (level > element_levels_[selectedNeighbors[idx]])
509
+ throw std::runtime_error("Trying to make a link on a non-existent level");
454
510
 
455
- bool is_cur_c_present = false;
456
- if (isUpdate) {
457
- for (size_t j = 0; j < sz_link_list_other; j++) {
458
- if (data[j] == cur_c) {
459
- is_cur_c_present = true;
460
- break;
461
- }
511
+ tableint *data = (tableint *) (ll_other + 1);
512
+
513
+ bool is_cur_c_present = false;
514
+ if (isUpdate) {
515
+ for (size_t j = 0; j < sz_link_list_other; j++) {
516
+ if (data[j] == cur_c) {
517
+ is_cur_c_present = true;
518
+ break;
462
519
  }
463
520
  }
521
+ }
464
522
 
465
- // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
466
- if (!is_cur_c_present) {
467
- if (sz_link_list_other < Mcurmax) {
468
- data[sz_link_list_other] = cur_c;
469
- setListCount(ll_other, sz_link_list_other + 1);
470
- } else {
471
- // finding the "weakest" element to replace it with the new one
472
- dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
473
- dist_func_param_);
474
- // Heuristic:
475
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
476
- candidates.emplace(d_max, cur_c);
477
-
478
- for (size_t j = 0; j < sz_link_list_other; j++) {
479
- candidates.emplace(
480
- fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
481
- dist_func_param_), data[j]);
482
- }
523
+ // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
524
+ if (!is_cur_c_present) {
525
+ if (sz_link_list_other < Mcurmax) {
526
+ data[sz_link_list_other] = cur_c;
527
+ setListCount(ll_other, sz_link_list_other + 1);
528
+ } else {
529
+ // finding the "weakest" element to replace it with the new one
530
+ dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
531
+ dist_func_param_);
532
+ // Heuristic:
533
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
534
+ candidates.emplace(d_max, cur_c);
483
535
 
484
- getNeighborsByHeuristic2(candidates, Mcurmax);
536
+ for (size_t j = 0; j < sz_link_list_other; j++) {
537
+ candidates.emplace(
538
+ fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
539
+ dist_func_param_), data[j]);
540
+ }
485
541
 
486
- int indx = 0;
487
- while (candidates.size() > 0) {
488
- data[indx] = candidates.top().second;
489
- candidates.pop();
490
- indx++;
491
- }
542
+ getNeighborsByHeuristic2(candidates, Mcurmax);
492
543
 
493
- setListCount(ll_other, indx);
494
- // Nearest K:
495
- /*int indx = -1;
496
- for (int j = 0; j < sz_link_list_other; j++) {
497
- dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
498
- if (d > d_max) {
499
- indx = j;
500
- d_max = d;
501
- }
544
+ int indx = 0;
545
+ while (candidates.size() > 0) {
546
+ data[indx] = candidates.top().second;
547
+ candidates.pop();
548
+ indx++;
549
+ }
550
+
551
+ setListCount(ll_other, indx);
552
+ // Nearest K:
553
+ /*int indx = -1;
554
+ for (int j = 0; j < sz_link_list_other; j++) {
555
+ dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
556
+ if (d > d_max) {
557
+ indx = j;
558
+ d_max = d;
502
559
  }
503
- if (indx >= 0) {
504
- data[indx] = cur_c;
505
- } */
506
560
  }
561
+ if (indx >= 0) {
562
+ data[indx] = cur_c;
563
+ } */
507
564
  }
508
565
  }
509
-
510
- return next_closest_entry_point;
511
566
  }
512
567
 
513
- std::mutex global;
514
- size_t ef_;
515
-
516
- void setEf(size_t ef) {
517
- ef_ = ef;
518
- }
568
+ return next_closest_entry_point;
569
+ }
519
570
 
520
571
 
521
- std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
522
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
523
- if (cur_element_count == 0) return top_candidates;
524
- tableint currObj = enterpoint_node_;
525
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
526
-
527
- for (size_t level = maxlevel_; level > 0; level--) {
528
- bool changed = true;
529
- while (changed) {
530
- changed = false;
531
- int *data;
532
- data = (int *) get_linklist(currObj,level);
533
- int size = getListCount(data);
534
- tableint *datal = (tableint *) (data + 1);
535
- for (int i = 0; i < size; i++) {
536
- tableint cand = datal[i];
537
- if (cand < 0 || cand > max_elements_)
538
- throw std::runtime_error("cand error");
539
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
572
+ void resizeIndex(size_t new_max_elements) {
573
+ if (new_max_elements < cur_element_count)
574
+ throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
540
575
 
541
- if (d < curdist) {
542
- curdist = d;
543
- currObj = cand;
544
- changed = true;
545
- }
546
- }
547
- }
548
- }
576
+ delete visited_list_pool_;
577
+ visited_list_pool_ = new VisitedListPool(1, new_max_elements);
549
578
 
550
- if (has_deletions_) {
551
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
552
- ef_);
553
- top_candidates.swap(top_candidates1);
554
- }
555
- else{
556
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
557
- ef_);
558
- top_candidates.swap(top_candidates1);
559
- }
579
+ element_levels_.resize(new_max_elements);
560
580
 
561
- while (top_candidates.size() > k) {
562
- top_candidates.pop();
563
- }
564
- return top_candidates;
565
- };
581
+ std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
566
582
 
567
- void resizeIndex(size_t new_max_elements){
568
- if (new_max_elements<cur_element_count)
569
- throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
583
+ // Reallocate base layer
584
+ char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
585
+ if (data_level0_memory_new == nullptr)
586
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
587
+ data_level0_memory_ = data_level0_memory_new;
570
588
 
589
+ // Reallocate all other layers
590
+ char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
591
+ if (linkLists_new == nullptr)
592
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
593
+ linkLists_ = linkLists_new;
571
594
 
572
- delete visited_list_pool_;
573
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
595
+ max_elements_ = new_max_elements;
596
+ }
574
597
 
575
598
 
576
- element_levels_.resize(new_max_elements);
599
+ void saveIndex(const std::string &location) {
600
+ std::ofstream output(location, std::ios::binary);
601
+ std::streampos position;
577
602
 
578
- std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
603
+ writeBinaryPOD(output, offsetLevel0_);
604
+ writeBinaryPOD(output, max_elements_);
605
+ writeBinaryPOD(output, cur_element_count);
606
+ writeBinaryPOD(output, size_data_per_element_);
607
+ writeBinaryPOD(output, label_offset_);
608
+ writeBinaryPOD(output, offsetData_);
609
+ writeBinaryPOD(output, maxlevel_);
610
+ writeBinaryPOD(output, enterpoint_node_);
611
+ writeBinaryPOD(output, maxM_);
579
612
 
580
- // Reallocate base layer
581
- char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
582
- if (data_level0_memory_new == nullptr)
583
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
584
- data_level0_memory_ = data_level0_memory_new;
613
+ writeBinaryPOD(output, maxM0_);
614
+ writeBinaryPOD(output, M_);
615
+ writeBinaryPOD(output, mult_);
616
+ writeBinaryPOD(output, ef_construction_);
585
617
 
586
- // Reallocate all other layers
587
- char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
588
- if (linkLists_new == nullptr)
589
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
590
- linkLists_ = linkLists_new;
618
+ output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
591
619
 
592
- max_elements_ = new_max_elements;
620
+ for (size_t i = 0; i < cur_element_count; i++) {
621
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
622
+ writeBinaryPOD(output, linkListSize);
623
+ if (linkListSize)
624
+ output.write(linkLists_[i], linkListSize);
593
625
  }
626
+ output.close();
627
+ }
628
+
629
+
630
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) {
631
+ std::ifstream input(location, std::ios::binary);
632
+
633
+ if (!input.is_open())
634
+ throw std::runtime_error("Cannot open file");
635
+
636
+ // get file size:
637
+ input.seekg(0, input.end);
638
+ std::streampos total_filesize = input.tellg();
639
+ input.seekg(0, input.beg);
640
+
641
+ readBinaryPOD(input, offsetLevel0_);
642
+ readBinaryPOD(input, max_elements_);
643
+ readBinaryPOD(input, cur_element_count);
644
+
645
+ size_t max_elements = max_elements_i;
646
+ if (max_elements < cur_element_count)
647
+ max_elements = max_elements_;
648
+ max_elements_ = max_elements;
649
+ readBinaryPOD(input, size_data_per_element_);
650
+ readBinaryPOD(input, label_offset_);
651
+ readBinaryPOD(input, offsetData_);
652
+ readBinaryPOD(input, maxlevel_);
653
+ readBinaryPOD(input, enterpoint_node_);
654
+
655
+ readBinaryPOD(input, maxM_);
656
+ readBinaryPOD(input, maxM0_);
657
+ readBinaryPOD(input, M_);
658
+ readBinaryPOD(input, mult_);
659
+ readBinaryPOD(input, ef_construction_);
660
+
661
+ data_size_ = s->get_data_size();
662
+ fstdistfunc_ = s->get_dist_func();
663
+ dist_func_param_ = s->get_dist_func_param();
664
+
665
+ auto pos = input.tellg();
666
+
667
+ /// Optional - check if index is ok:
668
+ input.seekg(cur_element_count * size_data_per_element_, input.cur);
669
+ for (size_t i = 0; i < cur_element_count; i++) {
670
+ if (input.tellg() < 0 || input.tellg() >= total_filesize) {
671
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
672
+ }
594
673
 
595
- void saveIndex(const std::string &location) {
596
- std::ofstream output(location, std::ios::binary);
597
- std::streampos position;
598
-
599
- writeBinaryPOD(output, offsetLevel0_);
600
- writeBinaryPOD(output, max_elements_);
601
- writeBinaryPOD(output, cur_element_count);
602
- writeBinaryPOD(output, size_data_per_element_);
603
- writeBinaryPOD(output, label_offset_);
604
- writeBinaryPOD(output, offsetData_);
605
- writeBinaryPOD(output, maxlevel_);
606
- writeBinaryPOD(output, enterpoint_node_);
607
- writeBinaryPOD(output, maxM_);
608
-
609
- writeBinaryPOD(output, maxM0_);
610
- writeBinaryPOD(output, M_);
611
- writeBinaryPOD(output, mult_);
612
- writeBinaryPOD(output, ef_construction_);
613
-
614
- output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
615
-
616
- for (size_t i = 0; i < cur_element_count; i++) {
617
- unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
618
- writeBinaryPOD(output, linkListSize);
619
- if (linkListSize)
620
- output.write(linkLists_[i], linkListSize);
674
+ unsigned int linkListSize;
675
+ readBinaryPOD(input, linkListSize);
676
+ if (linkListSize != 0) {
677
+ input.seekg(linkListSize, input.cur);
621
678
  }
622
- output.close();
623
679
  }
624
680
 
625
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
626
-
627
-
628
- std::ifstream input(location, std::ios::binary);
629
-
630
- if (!input.is_open())
631
- throw std::runtime_error("Cannot open file");
681
+ // throw exception if it either corrupted or old index
682
+ if (input.tellg() != total_filesize)
683
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
684
+
685
+ input.clear();
686
+ /// Optional check end
687
+
688
+ input.seekg(pos, input.beg);
689
+
690
+ data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
691
+ if (data_level0_memory_ == nullptr)
692
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
693
+ input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
694
+
695
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
696
+
697
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
698
+ std::vector<std::mutex>(max_elements).swap(link_list_locks_);
699
+ std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
700
+
701
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
702
+
703
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
704
+ if (linkLists_ == nullptr)
705
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
706
+ element_levels_ = std::vector<int>(max_elements);
707
+ revSize_ = 1.0 / mult_;
708
+ ef_ = 10;
709
+ for (size_t i = 0; i < cur_element_count; i++) {
710
+ label_lookup_[getExternalLabel(i)] = i;
711
+ unsigned int linkListSize;
712
+ readBinaryPOD(input, linkListSize);
713
+ if (linkListSize == 0) {
714
+ element_levels_[i] = 0;
715
+ linkLists_[i] = nullptr;
716
+ } else {
717
+ element_levels_[i] = linkListSize / size_links_per_element_;
718
+ linkLists_[i] = (char *) malloc(linkListSize);
719
+ if (linkLists_[i] == nullptr)
720
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
721
+ input.read(linkLists_[i], linkListSize);
722
+ }
723
+ }
632
724
 
633
- // get file size:
634
- input.seekg(0,input.end);
635
- std::streampos total_filesize=input.tellg();
636
- input.seekg(0,input.beg);
725
+ for (size_t i = 0; i < cur_element_count; i++) {
726
+ if (isMarkedDeleted(i)) {
727
+ num_deleted_ += 1;
728
+ if (allow_replace_deleted_) deleted_elements.insert(i);
729
+ }
730
+ }
637
731
 
638
- readBinaryPOD(input, offsetLevel0_);
639
- readBinaryPOD(input, max_elements_);
640
- readBinaryPOD(input, cur_element_count);
732
+ input.close();
641
733
 
642
- size_t max_elements=max_elements_i;
643
- if(max_elements < cur_element_count)
644
- max_elements = max_elements_;
645
- max_elements_ = max_elements;
646
- readBinaryPOD(input, size_data_per_element_);
647
- readBinaryPOD(input, label_offset_);
648
- readBinaryPOD(input, offsetData_);
649
- readBinaryPOD(input, maxlevel_);
650
- readBinaryPOD(input, enterpoint_node_);
734
+ return;
735
+ }
651
736
 
652
- readBinaryPOD(input, maxM_);
653
- readBinaryPOD(input, maxM0_);
654
- readBinaryPOD(input, M_);
655
- readBinaryPOD(input, mult_);
656
- readBinaryPOD(input, ef_construction_);
657
737
 
738
+ template<typename data_t>
739
+ std::vector<data_t> getDataByLabel(labeltype label) const {
740
+ // lock all operations with element by label
741
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
742
+
743
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
744
+ auto search = label_lookup_.find(label);
745
+ if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
746
+ throw std::runtime_error("Label not found");
747
+ }
748
+ tableint internalId = search->second;
749
+ lock_table.unlock();
750
+
751
+ char* data_ptrv = getDataByInternalId(internalId);
752
+ size_t dim = *((size_t *) dist_func_param_);
753
+ std::vector<data_t> data;
754
+ data_t* data_ptr = (data_t*) data_ptrv;
755
+ for (int i = 0; i < dim; i++) {
756
+ data.push_back(*data_ptr);
757
+ data_ptr += 1;
758
+ }
759
+ return data;
760
+ }
658
761
 
659
- data_size_ = s->get_data_size();
660
- fstdistfunc_ = s->get_dist_func();
661
- dist_func_param_ = s->get_dist_func_param();
662
762
 
663
- auto pos=input.tellg();
763
+ /*
764
+ * Marks an element with the given label deleted, does NOT really change the current graph.
765
+ */
766
+ void markDelete(labeltype label) {
767
+ // lock all operations with element by label
768
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
664
769
 
770
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
771
+ auto search = label_lookup_.find(label);
772
+ if (search == label_lookup_.end()) {
773
+ throw std::runtime_error("Label not found");
774
+ }
775
+ tableint internalId = search->second;
776
+ lock_table.unlock();
665
777
 
666
- /// Optional - check if index is ok:
778
+ markDeletedInternal(internalId);
779
+ }
667
780
 
668
- input.seekg(cur_element_count * size_data_per_element_,input.cur);
669
- for (size_t i = 0; i < cur_element_count; i++) {
670
- if(input.tellg() < 0 || input.tellg()>=total_filesize){
671
- throw std::runtime_error("Index seems to be corrupted or unsupported");
672
- }
673
781
 
674
- unsigned int linkListSize;
675
- readBinaryPOD(input, linkListSize);
676
- if (linkListSize != 0) {
677
- input.seekg(linkListSize,input.cur);
678
- }
782
+ /*
783
+ * Uses the last 16 bits of the memory for the linked list size to store the mark,
784
+ * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases.
785
+ */
786
+ void markDeletedInternal(tableint internalId) {
787
+ assert(internalId < cur_element_count);
788
+ if (!isMarkedDeleted(internalId)) {
789
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
790
+ *ll_cur |= DELETE_MARK;
791
+ num_deleted_ += 1;
792
+ if (allow_replace_deleted_) {
793
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
794
+ deleted_elements.insert(internalId);
679
795
  }
796
+ } else {
797
+ throw std::runtime_error("The requested to delete element is already deleted");
798
+ }
799
+ }
800
+
801
+
802
+ /*
803
+ * Removes the deleted mark of the node, does NOT really change the current graph.
804
+ *
805
+ * Note: the method is not safe to use when replacement of deleted elements is enabled,
806
+ * because elements marked as deleted can be completely removed by addPoint
807
+ */
808
+ void unmarkDelete(labeltype label) {
809
+ // lock all operations with element by label
810
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
811
+
812
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
813
+ auto search = label_lookup_.find(label);
814
+ if (search == label_lookup_.end()) {
815
+ throw std::runtime_error("Label not found");
816
+ }
817
+ tableint internalId = search->second;
818
+ lock_table.unlock();
680
819
 
681
- // throw exception if it either corrupted or old index
682
- if(input.tellg()!=total_filesize)
683
- throw std::runtime_error("Index seems to be corrupted or unsupported");
684
-
685
- input.clear();
686
-
687
- /// Optional check end
688
-
689
- input.seekg(pos,input.beg);
690
-
691
-
692
- data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
693
- if (data_level0_memory_ == nullptr)
694
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
695
- input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
696
-
697
-
820
+ unmarkDeletedInternal(internalId);
821
+ }
698
822
 
699
823
 
700
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
701
824
 
825
+ /*
826
+ * Remove the deleted mark of the node.
827
+ */
828
+ void unmarkDeletedInternal(tableint internalId) {
829
+ assert(internalId < cur_element_count);
830
+ if (isMarkedDeleted(internalId)) {
831
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;
832
+ *ll_cur &= ~DELETE_MARK;
833
+ num_deleted_ -= 1;
834
+ if (allow_replace_deleted_) {
835
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
836
+ deleted_elements.erase(internalId);
837
+ }
838
+ } else {
839
+ throw std::runtime_error("The requested to undelete element is not deleted");
840
+ }
841
+ }
702
842
 
703
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
704
- std::vector<std::mutex>(max_elements).swap(link_list_locks_);
705
- std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
706
843
 
844
+ /*
845
+ * Checks the first 16 bits of the memory to see if the element is marked deleted.
846
+ */
847
+ bool isMarkedDeleted(tableint internalId) const {
848
+ unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2;
849
+ return *ll_cur & DELETE_MARK;
850
+ }
707
851
 
708
- visited_list_pool_ = new VisitedListPool(1, max_elements);
709
852
 
853
+ unsigned short int getListCount(linklistsizeint * ptr) const {
854
+ return *((unsigned short int *)ptr);
855
+ }
710
856
 
711
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
712
- if (linkLists_ == nullptr)
713
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
714
- element_levels_ = std::vector<int>(max_elements);
715
- revSize_ = 1.0 / mult_;
716
- ef_ = 10;
717
- for (size_t i = 0; i < cur_element_count; i++) {
718
- label_lookup_[getExternalLabel(i)]=i;
719
- unsigned int linkListSize;
720
- readBinaryPOD(input, linkListSize);
721
- if (linkListSize == 0) {
722
- element_levels_[i] = 0;
723
857
 
724
- linkLists_[i] = nullptr;
725
- } else {
726
- element_levels_[i] = linkListSize / size_links_per_element_;
727
- linkLists_[i] = (char *) malloc(linkListSize);
728
- if (linkLists_[i] == nullptr)
729
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
730
- input.read(linkLists_[i], linkListSize);
731
- }
732
- }
858
+ void setListCount(linklistsizeint * ptr, unsigned short int size) const {
859
+ *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
860
+ }
733
861
 
734
- has_deletions_=false;
735
862
 
736
- for (size_t i = 0; i < cur_element_count; i++) {
737
- if(isMarkedDeleted(i))
738
- has_deletions_=true;
739
- }
740
-
741
- input.close();
863
+ /*
864
+ * Adds point. Updates the point if it is already in the index.
865
+ * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
866
+ */
867
+ void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
868
+ if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
869
+ throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
870
+ }
742
871
 
872
+ // lock all operations with element by label
873
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
874
+ if (!replace_deleted) {
875
+ addPoint(data_point, label, -1);
743
876
  return;
744
877
  }
745
-
746
- template<typename data_t>
747
- std::vector<data_t> getDataByLabel(labeltype label) const
748
- {
749
- tableint label_c;
750
- auto search = label_lookup_.find(label);
751
- if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
752
- throw std::runtime_error("Label not found");
753
- }
754
- label_c = search->second;
755
-
756
- char* data_ptrv = getDataByInternalId(label_c);
757
- size_t dim = *((size_t *) dist_func_param_);
758
- std::vector<data_t> data;
759
- data_t* data_ptr = (data_t*) data_ptrv;
760
- for (int i = 0; i < dim; i++) {
761
- data.push_back(*data_ptr);
762
- data_ptr += 1;
763
- }
764
- return data;
878
+ // check if there is vacant place
879
+ tableint internal_id_replaced;
880
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
881
+ bool is_vacant_place = !deleted_elements.empty();
882
+ if (is_vacant_place) {
883
+ internal_id_replaced = *deleted_elements.begin();
884
+ deleted_elements.erase(internal_id_replaced);
765
885
  }
766
-
767
- static const unsigned char DELETE_MARK = 0x01;
768
- // static const unsigned char REUSE_MARK = 0x10;
769
- /**
770
- * Marks an element with the given label deleted, does NOT really change the current graph.
771
- * @param label
772
- */
773
- void markDelete(labeltype label)
774
- {
775
- has_deletions_=true;
776
- auto search = label_lookup_.find(label);
777
- if (search == label_lookup_.end()) {
778
- throw std::runtime_error("Label not found");
779
- }
780
- markDeletedInternal(search->second);
886
+ lock_deleted_elements.unlock();
887
+
888
+ // if there is no vacant place then add or update point
889
+ // else add point to vacant place
890
+ if (!is_vacant_place) {
891
+ addPoint(data_point, label, -1);
892
+ } else {
893
+ // we assume that there are no concurrent operations on deleted element
894
+ labeltype label_replaced = getExternalLabel(internal_id_replaced);
895
+ setExternalLabel(internal_id_replaced, label);
896
+
897
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
898
+ label_lookup_.erase(label_replaced);
899
+ label_lookup_[label] = internal_id_replaced;
900
+ lock_table.unlock();
901
+
902
+ unmarkDeletedInternal(internal_id_replaced);
903
+ updatePoint(data_point, internal_id_replaced, 1.0);
781
904
  }
905
+ }
782
906
 
783
- /**
784
- * Uses the first 8 bits of the memory for the linked list to store the mark,
785
- * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
786
- * @param internalId
787
- */
788
- void markDeletedInternal(tableint internalId) {
789
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
790
- *ll_cur |= DELETE_MARK;
791
- }
792
907
 
793
- /**
794
- * Remove the deleted mark of the node.
795
- * @param internalId
796
- */
797
- void unmarkDeletedInternal(tableint internalId) {
798
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
799
- *ll_cur &= ~DELETE_MARK;
800
- }
908
+ void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
909
+ // update the feature vector associated with existing point with new vector
910
+ memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
801
911
 
802
- /**
803
- * Checks the first 8 bits of the memory to see if the element is marked deleted.
804
- * @param internalId
805
- * @return
806
- */
807
- bool isMarkedDeleted(tableint internalId) const {
808
- unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
809
- return *ll_cur & DELETE_MARK;
810
- }
912
+ int maxLevelCopy = maxlevel_;
913
+ tableint entryPointCopy = enterpoint_node_;
914
+ // If point to be updated is entry point and graph just contains single element then just return.
915
+ if (entryPointCopy == internalId && cur_element_count == 1)
916
+ return;
811
917
 
812
- unsigned short int getListCount(linklistsizeint * ptr) const {
813
- return *((unsigned short int *)ptr);
814
- }
918
+ int elemLevel = element_levels_[internalId];
919
+ std::uniform_real_distribution<float> distribution(0.0, 1.0);
920
+ for (int layer = 0; layer <= elemLevel; layer++) {
921
+ std::unordered_set<tableint> sCand;
922
+ std::unordered_set<tableint> sNeigh;
923
+ std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
924
+ if (listOneHop.size() == 0)
925
+ continue;
815
926
 
816
- void setListCount(linklistsizeint * ptr, unsigned short int size) const {
817
- *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
818
- }
927
+ sCand.insert(internalId);
819
928
 
820
- void addPoint(const void *data_point, labeltype label) {
821
- addPoint(data_point, label,-1);
822
- }
929
+ for (auto&& elOneHop : listOneHop) {
930
+ sCand.insert(elOneHop);
823
931
 
824
- void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
825
- // update the feature vector associated with existing point with new vector
826
- memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
827
-
828
- int maxLevelCopy = maxlevel_;
829
- tableint entryPointCopy = enterpoint_node_;
830
- // If point to be updated is entry point and graph just contains single element then just return.
831
- if (entryPointCopy == internalId && cur_element_count == 1)
832
- return;
833
-
834
- int elemLevel = element_levels_[internalId];
835
- std::uniform_real_distribution<float> distribution(0.0, 1.0);
836
- for (int layer = 0; layer <= elemLevel; layer++) {
837
- std::unordered_set<tableint> sCand;
838
- std::unordered_set<tableint> sNeigh;
839
- std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
840
- if (listOneHop.size() == 0)
932
+ if (distribution(update_probability_generator_) > updateNeighborProbability)
841
933
  continue;
842
934
 
843
- sCand.insert(internalId);
844
-
845
- for (auto&& elOneHop : listOneHop) {
846
- sCand.insert(elOneHop);
935
+ sNeigh.insert(elOneHop);
847
936
 
848
- if (distribution(update_probability_generator_) > updateNeighborProbability)
849
- continue;
850
-
851
- sNeigh.insert(elOneHop);
852
-
853
- std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
854
- for (auto&& elTwoHop : listTwoHop) {
855
- sCand.insert(elTwoHop);
856
- }
937
+ std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
938
+ for (auto&& elTwoHop : listTwoHop) {
939
+ sCand.insert(elTwoHop);
857
940
  }
941
+ }
858
942
 
859
- for (auto&& neigh : sNeigh) {
860
- // if (neigh == internalId)
861
- // continue;
943
+ for (auto&& neigh : sNeigh) {
944
+ // if (neigh == internalId)
945
+ // continue;
862
946
 
863
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
864
- size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
865
- size_t elementsToKeep = std::min(ef_construction_, size);
866
- for (auto&& cand : sCand) {
867
- if (cand == neigh)
868
- continue;
869
-
870
- dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
871
- if (candidates.size() < elementsToKeep) {
947
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
948
+ size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
949
+ size_t elementsToKeep = std::min(ef_construction_, size);
950
+ for (auto&& cand : sCand) {
951
+ if (cand == neigh)
952
+ continue;
953
+
954
+ dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
955
+ if (candidates.size() < elementsToKeep) {
956
+ candidates.emplace(distance, cand);
957
+ } else {
958
+ if (distance < candidates.top().first) {
959
+ candidates.pop();
872
960
  candidates.emplace(distance, cand);
873
- } else {
874
- if (distance < candidates.top().first) {
875
- candidates.pop();
876
- candidates.emplace(distance, cand);
877
- }
878
961
  }
879
962
  }
963
+ }
880
964
 
881
- // Retrieve neighbours using heuristic and set connections.
882
- getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
883
-
884
- {
885
- std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
886
- linklistsizeint *ll_cur;
887
- ll_cur = get_linklist_at_level(neigh, layer);
888
- size_t candSize = candidates.size();
889
- setListCount(ll_cur, candSize);
890
- tableint *data = (tableint *) (ll_cur + 1);
891
- for (size_t idx = 0; idx < candSize; idx++) {
892
- data[idx] = candidates.top().second;
893
- candidates.pop();
894
- }
965
+ // Retrieve neighbours using heuristic and set connections.
966
+ getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
967
+
968
+ {
969
+ std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
970
+ linklistsizeint *ll_cur;
971
+ ll_cur = get_linklist_at_level(neigh, layer);
972
+ size_t candSize = candidates.size();
973
+ setListCount(ll_cur, candSize);
974
+ tableint *data = (tableint *) (ll_cur + 1);
975
+ for (size_t idx = 0; idx < candSize; idx++) {
976
+ data[idx] = candidates.top().second;
977
+ candidates.pop();
895
978
  }
896
979
  }
897
980
  }
981
+ }
898
982
 
899
- repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
900
- };
983
+ repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
984
+ }
901
985
 
902
- void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
903
- tableint currObj = entryPointInternalId;
904
- if (dataPointLevel < maxLevel) {
905
- dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
906
- for (int level = maxLevel; level > dataPointLevel; level--) {
907
- bool changed = true;
908
- while (changed) {
909
- changed = false;
910
- unsigned int *data;
911
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
912
- data = get_linklist_at_level(currObj,level);
913
- int size = getListCount(data);
914
- tableint *datal = (tableint *) (data + 1);
986
+
987
+ void repairConnectionsForUpdate(
988
+ const void *dataPoint,
989
+ tableint entryPointInternalId,
990
+ tableint dataPointInternalId,
991
+ int dataPointLevel,
992
+ int maxLevel) {
993
+ tableint currObj = entryPointInternalId;
994
+ if (dataPointLevel < maxLevel) {
995
+ dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
996
+ for (int level = maxLevel; level > dataPointLevel; level--) {
997
+ bool changed = true;
998
+ while (changed) {
999
+ changed = false;
1000
+ unsigned int *data;
1001
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1002
+ data = get_linklist_at_level(currObj, level);
1003
+ int size = getListCount(data);
1004
+ tableint *datal = (tableint *) (data + 1);
915
1005
  #ifdef USE_SSE
916
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
1006
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
917
1007
  #endif
918
- for (int i = 0; i < size; i++) {
1008
+ for (int i = 0; i < size; i++) {
919
1009
  #ifdef USE_SSE
920
- _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
1010
+ _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
921
1011
  #endif
922
- tableint cand = datal[i];
923
- dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
924
- if (d < curdist) {
925
- curdist = d;
926
- currObj = cand;
927
- changed = true;
928
- }
1012
+ tableint cand = datal[i];
1013
+ dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
1014
+ if (d < curdist) {
1015
+ curdist = d;
1016
+ currObj = cand;
1017
+ changed = true;
929
1018
  }
930
1019
  }
931
1020
  }
932
1021
  }
1022
+ }
933
1023
 
934
- if (dataPointLevel > maxLevel)
935
- throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
936
-
937
- for (int level = dataPointLevel; level >= 0; level--) {
938
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
939
- currObj, dataPoint, level);
1024
+ if (dataPointLevel > maxLevel)
1025
+ throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
940
1026
 
941
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
942
- while (topCandidates.size() > 0) {
943
- if (topCandidates.top().second != dataPointInternalId)
944
- filteredTopCandidates.push(topCandidates.top());
1027
+ for (int level = dataPointLevel; level >= 0; level--) {
1028
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
1029
+ currObj, dataPoint, level);
945
1030
 
946
- topCandidates.pop();
947
- }
1031
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
1032
+ while (topCandidates.size() > 0) {
1033
+ if (topCandidates.top().second != dataPointInternalId)
1034
+ filteredTopCandidates.push(topCandidates.top());
948
1035
 
949
- // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
950
- // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
951
- if (filteredTopCandidates.size() > 0) {
952
- bool epDeleted = isMarkedDeleted(entryPointInternalId);
953
- if (epDeleted) {
954
- filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
955
- if (filteredTopCandidates.size() > ef_construction_)
956
- filteredTopCandidates.pop();
957
- }
1036
+ topCandidates.pop();
1037
+ }
958
1038
 
959
- currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
1039
+ // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
1040
+ // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
1041
+ if (filteredTopCandidates.size() > 0) {
1042
+ bool epDeleted = isMarkedDeleted(entryPointInternalId);
1043
+ if (epDeleted) {
1044
+ filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
1045
+ if (filteredTopCandidates.size() > ef_construction_)
1046
+ filteredTopCandidates.pop();
960
1047
  }
1048
+
1049
+ currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
961
1050
  }
962
1051
  }
1052
+ }
1053
+
963
1054
 
964
- std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
965
- std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
966
- unsigned int *data = get_linklist_at_level(internalId, level);
967
- int size = getListCount(data);
968
- std::vector<tableint> result(size);
969
- tableint *ll = (tableint *) (data + 1);
970
- memcpy(result.data(), ll,size * sizeof(tableint));
971
- return result;
972
- };
973
-
974
- tableint addPoint(const void *data_point, labeltype label, int level) {
975
-
976
- tableint cur_c = 0;
977
- {
978
- // Checking if the element with the same label already exists
979
- // if so, updating it *instead* of creating a new element.
980
- std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
981
- auto search = label_lookup_.find(label);
982
- if (search != label_lookup_.end()) {
983
- tableint existingInternalId = search->second;
984
- templock_curr.unlock();
985
-
986
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
1055
+ std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
1056
+ std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
1057
+ unsigned int *data = get_linklist_at_level(internalId, level);
1058
+ int size = getListCount(data);
1059
+ std::vector<tableint> result(size);
1060
+ tableint *ll = (tableint *) (data + 1);
1061
+ memcpy(result.data(), ll, size * sizeof(tableint));
1062
+ return result;
1063
+ }
987
1064
 
1065
+
1066
+ tableint addPoint(const void *data_point, labeltype label, int level) {
1067
+ tableint cur_c = 0;
1068
+ {
1069
+ // Checking if the element with the same label already exists
1070
+ // if so, updating it *instead* of creating a new element.
1071
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
1072
+ auto search = label_lookup_.find(label);
1073
+ if (search != label_lookup_.end()) {
1074
+ tableint existingInternalId = search->second;
1075
+ if (allow_replace_deleted_) {
988
1076
  if (isMarkedDeleted(existingInternalId)) {
989
- unmarkDeletedInternal(existingInternalId);
1077
+ throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
990
1078
  }
991
- updatePoint(data_point, existingInternalId, 1.0);
992
-
993
- return existingInternalId;
994
1079
  }
1080
+ lock_table.unlock();
995
1081
 
996
- if (cur_element_count >= max_elements_) {
997
- throw std::runtime_error("The number of elements exceeds the specified limit");
998
- };
1082
+ if (isMarkedDeleted(existingInternalId)) {
1083
+ unmarkDeletedInternal(existingInternalId);
1084
+ }
1085
+ updatePoint(data_point, existingInternalId, 1.0);
999
1086
 
1000
- cur_c = cur_element_count;
1001
- cur_element_count++;
1002
- label_lookup_[label] = cur_c;
1087
+ return existingInternalId;
1003
1088
  }
1004
1089
 
1005
- // Take update lock to prevent race conditions on an element with insertion/update at the same time.
1006
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
1007
- std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1008
- int curlevel = getRandomLevel(mult_);
1009
- if (level > 0)
1010
- curlevel = level;
1090
+ if (cur_element_count >= max_elements_) {
1091
+ throw std::runtime_error("The number of elements exceeds the specified limit");
1092
+ }
1011
1093
 
1012
- element_levels_[cur_c] = curlevel;
1094
+ cur_c = cur_element_count;
1095
+ cur_element_count++;
1096
+ label_lookup_[label] = cur_c;
1097
+ }
1013
1098
 
1099
+ std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1100
+ int curlevel = getRandomLevel(mult_);
1101
+ if (level > 0)
1102
+ curlevel = level;
1014
1103
 
1015
- std::unique_lock <std::mutex> templock(global);
1016
- int maxlevelcopy = maxlevel_;
1017
- if (curlevel <= maxlevelcopy)
1018
- templock.unlock();
1019
- tableint currObj = enterpoint_node_;
1020
- tableint enterpoint_copy = enterpoint_node_;
1104
+ element_levels_[cur_c] = curlevel;
1021
1105
 
1106
+ std::unique_lock <std::mutex> templock(global);
1107
+ int maxlevelcopy = maxlevel_;
1108
+ if (curlevel <= maxlevelcopy)
1109
+ templock.unlock();
1110
+ tableint currObj = enterpoint_node_;
1111
+ tableint enterpoint_copy = enterpoint_node_;
1022
1112
 
1023
- memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1113
+ memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1024
1114
 
1025
- // Initialisation of the data and label
1026
- memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1027
- memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1115
+ // Initialisation of the data and label
1116
+ memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1117
+ memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1028
1118
 
1119
+ if (curlevel) {
1120
+ linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1121
+ if (linkLists_[cur_c] == nullptr)
1122
+ throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1123
+ memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1124
+ }
1029
1125
 
1030
- if (curlevel) {
1031
- linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1032
- if (linkLists_[cur_c] == nullptr)
1033
- throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1034
- memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1035
- }
1126
+ if ((signed)currObj != -1) {
1127
+ if (curlevel < maxlevelcopy) {
1128
+ dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1129
+ for (int level = maxlevelcopy; level > curlevel; level--) {
1130
+ bool changed = true;
1131
+ while (changed) {
1132
+ changed = false;
1133
+ unsigned int *data;
1134
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1135
+ data = get_linklist(currObj, level);
1136
+ int size = getListCount(data);
1036
1137
 
1037
- if ((signed)currObj != -1) {
1038
-
1039
- if (curlevel < maxlevelcopy) {
1040
-
1041
- dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1042
- for (int level = maxlevelcopy; level > curlevel; level--) {
1043
-
1044
-
1045
- bool changed = true;
1046
- while (changed) {
1047
- changed = false;
1048
- unsigned int *data;
1049
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1050
- data = get_linklist(currObj,level);
1051
- int size = getListCount(data);
1052
-
1053
- tableint *datal = (tableint *) (data + 1);
1054
- for (int i = 0; i < size; i++) {
1055
- tableint cand = datal[i];
1056
- if (cand < 0 || cand > max_elements_)
1057
- throw std::runtime_error("cand error");
1058
- dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1059
- if (d < curdist) {
1060
- curdist = d;
1061
- currObj = cand;
1062
- changed = true;
1063
- }
1138
+ tableint *datal = (tableint *) (data + 1);
1139
+ for (int i = 0; i < size; i++) {
1140
+ tableint cand = datal[i];
1141
+ if (cand < 0 || cand > max_elements_)
1142
+ throw std::runtime_error("cand error");
1143
+ dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1144
+ if (d < curdist) {
1145
+ curdist = d;
1146
+ currObj = cand;
1147
+ changed = true;
1064
1148
  }
1065
1149
  }
1066
1150
  }
1067
1151
  }
1152
+ }
1068
1153
 
1069
- bool epDeleted = isMarkedDeleted(enterpoint_copy);
1070
- for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1071
- if (level > maxlevelcopy || level < 0) // possible?
1072
- throw std::runtime_error("Level error");
1073
-
1074
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1075
- currObj, data_point, level);
1076
- if (epDeleted) {
1077
- top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1078
- if (top_candidates.size() > ef_construction_)
1079
- top_candidates.pop();
1080
- }
1081
- currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1154
+ bool epDeleted = isMarkedDeleted(enterpoint_copy);
1155
+ for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1156
+ if (level > maxlevelcopy || level < 0) // possible?
1157
+ throw std::runtime_error("Level error");
1158
+
1159
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1160
+ currObj, data_point, level);
1161
+ if (epDeleted) {
1162
+ top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1163
+ if (top_candidates.size() > ef_construction_)
1164
+ top_candidates.pop();
1082
1165
  }
1083
-
1084
-
1085
- } else {
1086
- // Do nothing for the first element
1087
- enterpoint_node_ = 0;
1088
- maxlevel_ = curlevel;
1089
-
1166
+ currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1090
1167
  }
1168
+ } else {
1169
+ // Do nothing for the first element
1170
+ enterpoint_node_ = 0;
1171
+ maxlevel_ = curlevel;
1172
+ }
1091
1173
 
1092
- //Releasing lock for the maximum level
1093
- if (curlevel > maxlevelcopy) {
1094
- enterpoint_node_ = cur_c;
1095
- maxlevel_ = curlevel;
1096
- }
1097
- return cur_c;
1098
- };
1174
+ // Releasing lock for the maximum level
1175
+ if (curlevel > maxlevelcopy) {
1176
+ enterpoint_node_ = cur_c;
1177
+ maxlevel_ = curlevel;
1178
+ }
1179
+ return cur_c;
1180
+ }
1099
1181
 
1100
- std::priority_queue<std::pair<dist_t, labeltype >>
1101
- searchKnn(const void *query_data, size_t k) const {
1102
- std::priority_queue<std::pair<dist_t, labeltype >> result;
1103
- if (cur_element_count == 0) return result;
1104
1182
 
1105
- tableint currObj = enterpoint_node_;
1106
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1183
+ std::priority_queue<std::pair<dist_t, labeltype >>
1184
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
1185
+ std::priority_queue<std::pair<dist_t, labeltype >> result;
1186
+ if (cur_element_count == 0) return result;
1107
1187
 
1108
- for (int level = maxlevel_; level > 0; level--) {
1109
- bool changed = true;
1110
- while (changed) {
1111
- changed = false;
1112
- unsigned int *data;
1188
+ tableint currObj = enterpoint_node_;
1189
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1113
1190
 
1114
- data = (unsigned int *) get_linklist(currObj, level);
1115
- int size = getListCount(data);
1116
- metric_hops++;
1117
- metric_distance_computations+=size;
1191
+ for (int level = maxlevel_; level > 0; level--) {
1192
+ bool changed = true;
1193
+ while (changed) {
1194
+ changed = false;
1195
+ unsigned int *data;
1118
1196
 
1119
- tableint *datal = (tableint *) (data + 1);
1120
- for (int i = 0; i < size; i++) {
1121
- tableint cand = datal[i];
1122
- if (cand < 0 || cand > max_elements_)
1123
- throw std::runtime_error("cand error");
1124
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1197
+ data = (unsigned int *) get_linklist(currObj, level);
1198
+ int size = getListCount(data);
1199
+ metric_hops++;
1200
+ metric_distance_computations+=size;
1125
1201
 
1126
- if (d < curdist) {
1127
- curdist = d;
1128
- currObj = cand;
1129
- changed = true;
1130
- }
1202
+ tableint *datal = (tableint *) (data + 1);
1203
+ for (int i = 0; i < size; i++) {
1204
+ tableint cand = datal[i];
1205
+ if (cand < 0 || cand > max_elements_)
1206
+ throw std::runtime_error("cand error");
1207
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1208
+
1209
+ if (d < curdist) {
1210
+ curdist = d;
1211
+ currObj = cand;
1212
+ changed = true;
1131
1213
  }
1132
1214
  }
1133
1215
  }
1216
+ }
1134
1217
 
1135
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1136
- if (has_deletions_) {
1137
- top_candidates=searchBaseLayerST<true,true>(
1138
- currObj, query_data, std::max(ef_, k));
1139
- }
1140
- else{
1141
- top_candidates=searchBaseLayerST<false,true>(
1142
- currObj, query_data, std::max(ef_, k));
1143
- }
1218
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1219
+ if (num_deleted_) {
1220
+ top_candidates = searchBaseLayerST<true, true>(
1221
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1222
+ } else {
1223
+ top_candidates = searchBaseLayerST<false, true>(
1224
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1225
+ }
1144
1226
 
1145
- while (top_candidates.size() > k) {
1146
- top_candidates.pop();
1147
- }
1148
- while (top_candidates.size() > 0) {
1149
- std::pair<dist_t, tableint> rez = top_candidates.top();
1150
- result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1151
- top_candidates.pop();
1152
- }
1153
- return result;
1154
- };
1155
-
1156
- void checkIntegrity(){
1157
- int connections_checked=0;
1158
- std::vector <int > inbound_connections_num(cur_element_count,0);
1159
- for(int i = 0;i < cur_element_count; i++){
1160
- for(int l = 0;l <= element_levels_[i]; l++){
1161
- linklistsizeint *ll_cur = get_linklist_at_level(i,l);
1162
- int size = getListCount(ll_cur);
1163
- tableint *data = (tableint *) (ll_cur + 1);
1164
- std::unordered_set<tableint> s;
1165
- for (int j=0; j<size; j++){
1166
- assert(data[j] > 0);
1167
- assert(data[j] < cur_element_count);
1168
- assert (data[j] != i);
1169
- inbound_connections_num[data[j]]++;
1170
- s.insert(data[j]);
1171
- connections_checked++;
1227
+ while (top_candidates.size() > k) {
1228
+ top_candidates.pop();
1229
+ }
1230
+ while (top_candidates.size() > 0) {
1231
+ std::pair<dist_t, tableint> rez = top_candidates.top();
1232
+ result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1233
+ top_candidates.pop();
1234
+ }
1235
+ return result;
1236
+ }
1172
1237
 
1173
- }
1174
- assert(s.size() == size);
1238
+
1239
+ void checkIntegrity() {
1240
+ int connections_checked = 0;
1241
+ std::vector <int > inbound_connections_num(cur_element_count, 0);
1242
+ for (int i = 0; i < cur_element_count; i++) {
1243
+ for (int l = 0; l <= element_levels_[i]; l++) {
1244
+ linklistsizeint *ll_cur = get_linklist_at_level(i, l);
1245
+ int size = getListCount(ll_cur);
1246
+ tableint *data = (tableint *) (ll_cur + 1);
1247
+ std::unordered_set<tableint> s;
1248
+ for (int j = 0; j < size; j++) {
1249
+ assert(data[j] > 0);
1250
+ assert(data[j] < cur_element_count);
1251
+ assert(data[j] != i);
1252
+ inbound_connections_num[data[j]]++;
1253
+ s.insert(data[j]);
1254
+ connections_checked++;
1175
1255
  }
1256
+ assert(s.size() == size);
1176
1257
  }
1177
- if(cur_element_count > 1){
1178
- int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1179
- for(int i=0; i < cur_element_count; i++){
1180
- assert(inbound_connections_num[i] > 0);
1181
- min1=std::min(inbound_connections_num[i],min1);
1182
- max1=std::max(inbound_connections_num[i],max1);
1183
- }
1184
- std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1258
+ }
1259
+ if (cur_element_count > 1) {
1260
+ int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
1261
+ for (int i=0; i < cur_element_count; i++) {
1262
+ assert(inbound_connections_num[i] > 0);
1263
+ min1 = std::min(inbound_connections_num[i], min1);
1264
+ max1 = std::max(inbound_connections_num[i], max1);
1185
1265
  }
1186
- std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1187
-
1266
+ std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1188
1267
  }
1189
-
1190
- };
1191
-
1192
- }
1268
+ std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1269
+ }
1270
+ };
1271
+ } // namespace hnswlib