hnswlib 0.6.2 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,1199 +10,1263 @@
10
10
  #include <list>
11
11
 
12
12
  namespace hnswlib {
13
- typedef unsigned int tableint;
14
- typedef unsigned int linklistsizeint;
15
-
16
- template<typename dist_t>
17
- class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
- public:
19
- static const tableint max_update_element_locks = 65536;
20
- HierarchicalNSW() : visited_list_pool_(nullptr), data_level0_memory_(nullptr), linkLists_(nullptr), cur_element_count(0) { }
21
- HierarchicalNSW(SpaceInterface<dist_t> *s) {
22
- }
13
+ typedef unsigned int tableint;
14
+ typedef unsigned int linklistsizeint;
15
+
16
+ template<typename dist_t>
17
+ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
+ public:
19
+ static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
20
+ static const unsigned char DELETE_MARK = 0x01;
21
+
22
+ size_t max_elements_{0};
23
+ mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
24
+ size_t size_data_per_element_{0};
25
+ size_t size_links_per_element_{0};
26
+ mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
27
+ size_t M_{0};
28
+ size_t maxM_{0};
29
+ size_t maxM0_{0};
30
+ size_t ef_construction_{0};
31
+ size_t ef_{ 0 };
32
+
33
+ double mult_{0.0}, revSize_{0.0};
34
+ int maxlevel_{0};
35
+
36
+ VisitedListPool *visited_list_pool_{nullptr};
37
+
38
+ // Locks operations with element by label value
39
+ mutable std::vector<std::mutex> label_op_locks_;
40
+
41
+ std::mutex global;
42
+ std::vector<std::mutex> link_list_locks_;
43
+
44
+ tableint enterpoint_node_{0};
45
+
46
+ size_t size_links_level0_{0};
47
+ size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 };
48
+
49
+ char *data_level0_memory_{nullptr};
50
+ char **linkLists_{nullptr};
51
+ std::vector<int> element_levels_; // keeps level of each element
52
+
53
+ size_t data_size_{0};
54
+
55
+ DISTFUNC<dist_t> fstdistfunc_;
56
+ void *dist_func_param_{nullptr};
57
+
58
+ mutable std::mutex label_lookup_lock; // lock for label_lookup_
59
+ std::unordered_map<labeltype, tableint> label_lookup_;
60
+
61
+ std::default_random_engine level_generator_;
62
+ std::default_random_engine update_probability_generator_;
63
+
64
+ mutable std::atomic<long> metric_distance_computations{0};
65
+ mutable std::atomic<long> metric_hops{0};
66
+
67
+ bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
68
+
69
+ std::mutex deleted_elements_lock; // lock for deleted_elements
70
+ std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
71
+
72
+ HierarchicalNSW() { }
73
+
74
+ HierarchicalNSW(SpaceInterface<dist_t> *s) {
75
+ }
76
+
77
+
78
+ HierarchicalNSW(
79
+ SpaceInterface<dist_t> *s,
80
+ const std::string &location,
81
+ bool nmslib = false,
82
+ size_t max_elements = 0,
83
+ bool allow_replace_deleted = false)
84
+ : allow_replace_deleted_(allow_replace_deleted) {
85
+ loadIndex(location, s, max_elements);
86
+ }
87
+
88
+
89
+ HierarchicalNSW(
90
+ SpaceInterface<dist_t> *s,
91
+ size_t max_elements,
92
+ size_t M = 16,
93
+ size_t ef_construction = 200,
94
+ size_t random_seed = 100,
95
+ bool allow_replace_deleted = false)
96
+ : link_list_locks_(max_elements),
97
+ label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
98
+ element_levels_(max_elements),
99
+ allow_replace_deleted_(allow_replace_deleted) {
100
+ max_elements_ = max_elements;
101
+ num_deleted_ = 0;
102
+ data_size_ = s->get_data_size();
103
+ fstdistfunc_ = s->get_dist_func();
104
+ dist_func_param_ = s->get_dist_func_param();
105
+ M_ = M;
106
+ maxM_ = M_;
107
+ maxM0_ = M_ * 2;
108
+ ef_construction_ = std::max(ef_construction, M_);
109
+ ef_ = 10;
110
+
111
+ level_generator_.seed(random_seed);
112
+ update_probability_generator_.seed(random_seed + 1);
113
+
114
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
115
+ size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
116
+ offsetData_ = size_links_level0_;
117
+ label_offset_ = size_links_level0_ + data_size_;
118
+ offsetLevel0_ = 0;
119
+
120
+ data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
121
+ if (data_level0_memory_ == nullptr)
122
+ throw std::runtime_error("Not enough memory");
123
+
124
+ cur_element_count = 0;
125
+
126
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
23
127
 
24
- HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
25
- loadIndex(location, s, max_elements);
26
- }
128
+ // initializations for special treatment of the first node
129
+ enterpoint_node_ = -1;
130
+ maxlevel_ = -1;
131
+
132
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
133
+ if (linkLists_ == nullptr)
134
+ throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
135
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
136
+ mult_ = 1 / log(1.0 * M_);
137
+ revSize_ = 1.0 / mult_;
138
+ }
27
139
 
28
- HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
29
- link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
30
- max_elements_ = max_elements;
31
-
32
- num_deleted_ = 0;
33
- data_size_ = s->get_data_size();
34
- fstdistfunc_ = s->get_dist_func();
35
- dist_func_param_ = s->get_dist_func_param();
36
- M_ = M;
37
- maxM_ = M_;
38
- maxM0_ = M_ * 2;
39
- ef_construction_ = std::max(ef_construction,M_);
40
- ef_ = 10;
41
-
42
- level_generator_.seed(random_seed);
43
- update_probability_generator_.seed(random_seed + 1);
44
-
45
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
46
- size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
47
- offsetData_ = size_links_level0_;
48
- label_offset_ = size_links_level0_ + data_size_;
49
- offsetLevel0_ = 0;
50
-
51
- data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
52
- if (data_level0_memory_ == nullptr)
53
- throw std::runtime_error("Not enough memory");
54
-
55
- cur_element_count = 0;
56
-
57
- visited_list_pool_ = new VisitedListPool(1, max_elements);
58
-
59
- //initializations for special treatment of the first node
60
- enterpoint_node_ = -1;
61
- maxlevel_ = -1;
62
-
63
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
64
- if (linkLists_ == nullptr)
65
- throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
66
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
67
- mult_ = 1 / log(1.0 * M_);
68
- revSize_ = 1.0 / mult_;
140
+
141
+ ~HierarchicalNSW() {
142
+ free(data_level0_memory_);
143
+ for (tableint i = 0; i < cur_element_count; i++) {
144
+ if (element_levels_[i] > 0)
145
+ free(linkLists_[i]);
69
146
  }
147
+ free(linkLists_);
148
+ delete visited_list_pool_;
149
+ }
70
150
 
71
- struct CompareByFirst {
72
- constexpr bool operator()(std::pair<dist_t, tableint> const &a,
73
- std::pair<dist_t, tableint> const &b) const noexcept {
74
- return a.first < b.first;
75
- }
76
- };
77
-
78
- ~HierarchicalNSW() {
79
- if (data_level0_memory_) free(data_level0_memory_);
80
- if (linkLists_) {
81
- for (tableint i = 0; i < cur_element_count; i++) {
82
- if (element_levels_[i] > 0)
83
- if (linkLists_[i]) free(linkLists_[i]);
84
- }
85
- free(linkLists_);
86
- }
87
- if (visited_list_pool_) delete visited_list_pool_;
151
+
152
+ struct CompareByFirst {
153
+ constexpr bool operator()(std::pair<dist_t, tableint> const& a,
154
+ std::pair<dist_t, tableint> const& b) const noexcept {
155
+ return a.first < b.first;
88
156
  }
157
+ };
89
158
 
90
- size_t max_elements_;
91
- size_t cur_element_count;
92
- size_t size_data_per_element_;
93
- size_t size_links_per_element_;
94
- size_t num_deleted_;
95
159
 
96
- size_t M_;
97
- size_t maxM_;
98
- size_t maxM0_;
99
- size_t ef_construction_;
160
+ void setEf(size_t ef) {
161
+ ef_ = ef;
162
+ }
100
163
 
101
- double mult_, revSize_;
102
- int maxlevel_;
103
164
 
165
+ inline std::mutex& getLabelOpMutex(labeltype label) const {
166
+ // calculate hash
167
+ size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
168
+ return label_op_locks_[lock_id];
169
+ }
104
170
 
105
- VisitedListPool *visited_list_pool_;
106
- std::mutex cur_element_count_guard_;
107
171
 
108
- std::vector<std::mutex> link_list_locks_;
172
+ inline labeltype getExternalLabel(tableint internal_id) const {
173
+ labeltype return_label;
174
+ memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
175
+ return return_label;
176
+ }
109
177
 
110
- // Locks to prevent race condition during update/insert of an element at same time.
111
- // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
112
- std::vector<std::mutex> link_list_update_locks_;
113
- tableint enterpoint_node_;
114
178
 
115
- size_t size_links_level0_;
116
- size_t offsetData_, offsetLevel0_;
179
+ inline void setExternalLabel(tableint internal_id, labeltype label) const {
180
+ memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
181
+ }
117
182
 
118
- char *data_level0_memory_;
119
- char **linkLists_;
120
- std::vector<int> element_levels_;
121
183
 
122
- size_t data_size_;
184
+ inline labeltype *getExternalLabeLp(tableint internal_id) const {
185
+ return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
186
+ }
123
187
 
124
- size_t label_offset_;
125
- DISTFUNC<dist_t> fstdistfunc_;
126
- void *dist_func_param_;
127
- std::unordered_map<labeltype, tableint> label_lookup_;
128
188
 
129
- std::default_random_engine level_generator_;
130
- std::default_random_engine update_probability_generator_;
189
+ inline char *getDataByInternalId(tableint internal_id) const {
190
+ return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
191
+ }
131
192
 
132
- inline labeltype getExternalLabel(tableint internal_id) const {
133
- labeltype return_label;
134
- memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
135
- return return_label;
136
- }
137
193
 
138
- inline void setExternalLabel(tableint internal_id, labeltype label) const {
139
- memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
140
- }
194
+ int getRandomLevel(double reverse_size) {
195
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
196
+ double r = -log(distribution(level_generator_)) * reverse_size;
197
+ return (int) r;
198
+ }
141
199
 
142
- inline labeltype *getExternalLabeLp(tableint internal_id) const {
143
- return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
144
- }
200
+ size_t getMaxElements() {
201
+ return max_elements_;
202
+ }
145
203
 
146
- inline char *getDataByInternalId(tableint internal_id) const {
147
- return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
148
- }
204
+ size_t getCurrentElementCount() {
205
+ return cur_element_count;
206
+ }
149
207
 
150
- int getRandomLevel(double reverse_size) {
151
- std::uniform_real_distribution<double> distribution(0.0, 1.0);
152
- double r = -log(distribution(level_generator_)) * reverse_size;
153
- return (int) r;
154
- }
208
+ size_t getDeletedCount() {
209
+ return num_deleted_;
210
+ }
155
211
 
212
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
213
+ searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
214
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
215
+ vl_type *visited_array = vl->mass;
216
+ vl_type visited_array_tag = vl->curV;
156
217
 
157
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
158
- searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
159
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
160
- vl_type *visited_array = vl->mass;
161
- vl_type visited_array_tag = vl->curV;
218
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
219
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
162
220
 
163
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
164
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
221
+ dist_t lowerBound;
222
+ if (!isMarkedDeleted(ep_id)) {
223
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
224
+ top_candidates.emplace(dist, ep_id);
225
+ lowerBound = dist;
226
+ candidateSet.emplace(-dist, ep_id);
227
+ } else {
228
+ lowerBound = std::numeric_limits<dist_t>::max();
229
+ candidateSet.emplace(-lowerBound, ep_id);
230
+ }
231
+ visited_array[ep_id] = visited_array_tag;
165
232
 
166
- dist_t lowerBound;
167
- if (!isMarkedDeleted(ep_id)) {
168
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
169
- top_candidates.emplace(dist, ep_id);
170
- lowerBound = dist;
171
- candidateSet.emplace(-dist, ep_id);
172
- } else {
173
- lowerBound = std::numeric_limits<dist_t>::max();
174
- candidateSet.emplace(-lowerBound, ep_id);
233
+ while (!candidateSet.empty()) {
234
+ std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
235
+ if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
236
+ break;
175
237
  }
176
- visited_array[ep_id] = visited_array_tag;
177
-
178
- while (!candidateSet.empty()) {
179
- std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
180
- if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
181
- break;
182
- }
183
- candidateSet.pop();
238
+ candidateSet.pop();
184
239
 
185
- tableint curNodeNum = curr_el_pair.second;
240
+ tableint curNodeNum = curr_el_pair.second;
186
241
 
187
- std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
242
+ std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
188
243
 
189
- int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
190
- if (layer == 0) {
191
- data = (int*)get_linklist0(curNodeNum);
192
- } else {
193
- data = (int*)get_linklist(curNodeNum, layer);
244
+ int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
245
+ if (layer == 0) {
246
+ data = (int*)get_linklist0(curNodeNum);
247
+ } else {
248
+ data = (int*)get_linklist(curNodeNum, layer);
194
249
  // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
195
- }
196
- size_t size = getListCount((linklistsizeint*)data);
197
- tableint *datal = (tableint *) (data + 1);
250
+ }
251
+ size_t size = getListCount((linklistsizeint*)data);
252
+ tableint *datal = (tableint *) (data + 1);
198
253
  #ifdef USE_SSE
199
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
200
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
201
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
202
- _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
254
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
255
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
256
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
257
+ _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
203
258
  #endif
204
259
 
205
- for (size_t j = 0; j < size; j++) {
206
- tableint candidate_id = *(datal + j);
260
+ for (size_t j = 0; j < size; j++) {
261
+ tableint candidate_id = *(datal + j);
207
262
  // if (candidate_id == 0) continue;
208
263
  #ifdef USE_SSE
209
- _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
210
- _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
264
+ _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
265
+ _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
211
266
  #endif
212
- if (visited_array[candidate_id] == visited_array_tag) continue;
213
- visited_array[candidate_id] = visited_array_tag;
214
- char *currObj1 = (getDataByInternalId(candidate_id));
267
+ if (visited_array[candidate_id] == visited_array_tag) continue;
268
+ visited_array[candidate_id] = visited_array_tag;
269
+ char *currObj1 = (getDataByInternalId(candidate_id));
215
270
 
216
- dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
217
- if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
218
- candidateSet.emplace(-dist1, candidate_id);
271
+ dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
272
+ if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
273
+ candidateSet.emplace(-dist1, candidate_id);
219
274
  #ifdef USE_SSE
220
- _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
275
+ _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
221
276
  #endif
222
277
 
223
- if (!isMarkedDeleted(candidate_id))
224
- top_candidates.emplace(dist1, candidate_id);
278
+ if (!isMarkedDeleted(candidate_id))
279
+ top_candidates.emplace(dist1, candidate_id);
225
280
 
226
- if (top_candidates.size() > ef_construction_)
227
- top_candidates.pop();
281
+ if (top_candidates.size() > ef_construction_)
282
+ top_candidates.pop();
228
283
 
229
- if (!top_candidates.empty())
230
- lowerBound = top_candidates.top().first;
231
- }
284
+ if (!top_candidates.empty())
285
+ lowerBound = top_candidates.top().first;
232
286
  }
233
287
  }
234
- visited_list_pool_->releaseVisitedList(vl);
235
-
236
- return top_candidates;
288
+ }
289
+ visited_list_pool_->releaseVisitedList(vl);
290
+
291
+ return top_candidates;
292
+ }
293
+
294
+
295
+ template <bool has_deletions, bool collect_metrics = false>
296
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
297
+ searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
298
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
299
+ vl_type *visited_array = vl->mass;
300
+ vl_type visited_array_tag = vl->curV;
301
+
302
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
303
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
304
+
305
+ dist_t lowerBound;
306
+ if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
307
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
308
+ lowerBound = dist;
309
+ top_candidates.emplace(dist, ep_id);
310
+ candidate_set.emplace(-dist, ep_id);
311
+ } else {
312
+ lowerBound = std::numeric_limits<dist_t>::max();
313
+ candidate_set.emplace(-lowerBound, ep_id);
237
314
  }
238
315
 
239
- mutable std::atomic<long> metric_distance_computations;
240
- mutable std::atomic<long> metric_hops;
241
-
242
- template <bool has_deletions, bool collect_metrics=false>
243
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
244
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
245
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
246
- vl_type *visited_array = vl->mass;
247
- vl_type visited_array_tag = vl->curV;
248
-
249
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
250
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
251
-
252
- dist_t lowerBound;
253
- if (!has_deletions || !isMarkedDeleted(ep_id)) {
254
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
255
- lowerBound = dist;
256
- top_candidates.emplace(dist, ep_id);
257
- candidate_set.emplace(-dist, ep_id);
258
- } else {
259
- lowerBound = std::numeric_limits<dist_t>::max();
260
- candidate_set.emplace(-lowerBound, ep_id);
261
- }
262
-
263
- visited_array[ep_id] = visited_array_tag;
264
-
265
- while (!candidate_set.empty()) {
316
+ visited_array[ep_id] = visited_array_tag;
266
317
 
267
- std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
318
+ while (!candidate_set.empty()) {
319
+ std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
268
320
 
269
- if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) {
270
- break;
271
- }
272
- candidate_set.pop();
321
+ if ((-current_node_pair.first) > lowerBound &&
322
+ (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
323
+ break;
324
+ }
325
+ candidate_set.pop();
273
326
 
274
- tableint current_node_id = current_node_pair.second;
275
- int *data = (int *) get_linklist0(current_node_id);
276
- size_t size = getListCount((linklistsizeint*)data);
327
+ tableint current_node_id = current_node_pair.second;
328
+ int *data = (int *) get_linklist0(current_node_id);
329
+ size_t size = getListCount((linklistsizeint*)data);
277
330
  // bool cur_node_deleted = isMarkedDeleted(current_node_id);
278
- if(collect_metrics){
279
- metric_hops++;
280
- metric_distance_computations+=size;
281
- }
331
+ if (collect_metrics) {
332
+ metric_hops++;
333
+ metric_distance_computations+=size;
334
+ }
282
335
 
283
336
  #ifdef USE_SSE
284
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
285
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
286
- _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
287
- _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
337
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
338
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
339
+ _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
340
+ _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
288
341
  #endif
289
342
 
290
- for (size_t j = 1; j <= size; j++) {
291
- int candidate_id = *(data + j);
343
+ for (size_t j = 1; j <= size; j++) {
344
+ int candidate_id = *(data + j);
292
345
  // if (candidate_id == 0) continue;
293
346
  #ifdef USE_SSE
294
- _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
295
- _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
296
- _MM_HINT_T0);////////////
347
+ _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
348
+ _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
349
+ _MM_HINT_T0); ////////////
297
350
  #endif
298
- if (!(visited_array[candidate_id] == visited_array_tag)) {
299
-
300
- visited_array[candidate_id] = visited_array_tag;
351
+ if (!(visited_array[candidate_id] == visited_array_tag)) {
352
+ visited_array[candidate_id] = visited_array_tag;
301
353
 
302
- char *currObj1 = (getDataByInternalId(candidate_id));
303
- dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
354
+ char *currObj1 = (getDataByInternalId(candidate_id));
355
+ dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
304
356
 
305
- if (top_candidates.size() < ef || lowerBound > dist) {
306
- candidate_set.emplace(-dist, candidate_id);
357
+ if (top_candidates.size() < ef || lowerBound > dist) {
358
+ candidate_set.emplace(-dist, candidate_id);
307
359
  #ifdef USE_SSE
308
- _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
309
- offsetLevel0_,///////////
310
- _MM_HINT_T0);////////////////////////
360
+ _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
361
+ offsetLevel0_, ///////////
362
+ _MM_HINT_T0); ////////////////////////
311
363
  #endif
312
364
 
313
- if (!has_deletions || !isMarkedDeleted(candidate_id))
314
- top_candidates.emplace(dist, candidate_id);
365
+ if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
366
+ top_candidates.emplace(dist, candidate_id);
315
367
 
316
- if (top_candidates.size() > ef)
317
- top_candidates.pop();
368
+ if (top_candidates.size() > ef)
369
+ top_candidates.pop();
318
370
 
319
- if (!top_candidates.empty())
320
- lowerBound = top_candidates.top().first;
321
- }
371
+ if (!top_candidates.empty())
372
+ lowerBound = top_candidates.top().first;
322
373
  }
323
374
  }
324
375
  }
325
-
326
- visited_list_pool_->releaseVisitedList(vl);
327
- return top_candidates;
328
376
  }
329
377
 
330
- void getNeighborsByHeuristic2(
331
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
332
- const size_t M) {
333
- if (top_candidates.size() < M) {
334
- return;
335
- }
378
+ visited_list_pool_->releaseVisitedList(vl);
379
+ return top_candidates;
380
+ }
336
381
 
337
- std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
338
- std::vector<std::pair<dist_t, tableint>> return_list;
339
- while (top_candidates.size() > 0) {
340
- queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
341
- top_candidates.pop();
342
- }
343
382
 
344
- while (queue_closest.size()) {
345
- if (return_list.size() >= M)
383
+ void getNeighborsByHeuristic2(
384
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
385
+ const size_t M) {
386
+ if (top_candidates.size() < M) {
387
+ return;
388
+ }
389
+
390
+ std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
391
+ std::vector<std::pair<dist_t, tableint>> return_list;
392
+ while (top_candidates.size() > 0) {
393
+ queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
394
+ top_candidates.pop();
395
+ }
396
+
397
+ while (queue_closest.size()) {
398
+ if (return_list.size() >= M)
399
+ break;
400
+ std::pair<dist_t, tableint> curent_pair = queue_closest.top();
401
+ dist_t dist_to_query = -curent_pair.first;
402
+ queue_closest.pop();
403
+ bool good = true;
404
+
405
+ for (std::pair<dist_t, tableint> second_pair : return_list) {
406
+ dist_t curdist =
407
+ fstdistfunc_(getDataByInternalId(second_pair.second),
408
+ getDataByInternalId(curent_pair.second),
409
+ dist_func_param_);
410
+ if (curdist < dist_to_query) {
411
+ good = false;
346
412
  break;
347
- std::pair<dist_t, tableint> curent_pair = queue_closest.top();
348
- dist_t dist_to_query = -curent_pair.first;
349
- queue_closest.pop();
350
- bool good = true;
351
-
352
- for (std::pair<dist_t, tableint> second_pair : return_list) {
353
- dist_t curdist =
354
- fstdistfunc_(getDataByInternalId(second_pair.second),
355
- getDataByInternalId(curent_pair.second),
356
- dist_func_param_);;
357
- if (curdist < dist_to_query) {
358
- good = false;
359
- break;
360
- }
361
- }
362
- if (good) {
363
- return_list.push_back(curent_pair);
364
413
  }
365
414
  }
366
-
367
- for (std::pair<dist_t, tableint> curent_pair : return_list) {
368
- top_candidates.emplace(-curent_pair.first, curent_pair.second);
415
+ if (good) {
416
+ return_list.push_back(curent_pair);
369
417
  }
370
418
  }
371
419
 
420
+ for (std::pair<dist_t, tableint> curent_pair : return_list) {
421
+ top_candidates.emplace(-curent_pair.first, curent_pair.second);
422
+ }
423
+ }
372
424
 
373
- linklistsizeint *get_linklist0(tableint internal_id) const {
374
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
375
- };
376
425
 
377
- linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
378
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
379
- };
426
+ linklistsizeint *get_linklist0(tableint internal_id) const {
427
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
428
+ }
380
429
 
381
- linklistsizeint *get_linklist(tableint internal_id, int level) const {
382
- return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
383
- };
384
430
 
385
- linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
386
- return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
387
- };
431
+ linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
432
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
433
+ }
388
434
 
389
- tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
390
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
391
- int level, bool isUpdate) {
392
- size_t Mcurmax = level ? maxM_ : maxM0_;
393
- getNeighborsByHeuristic2(top_candidates, M_);
394
- if (top_candidates.size() > M_)
395
- throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
396
435
 
397
- std::vector<tableint> selectedNeighbors;
398
- selectedNeighbors.reserve(M_);
399
- while (top_candidates.size() > 0) {
400
- selectedNeighbors.push_back(top_candidates.top().second);
401
- top_candidates.pop();
402
- }
436
+ linklistsizeint *get_linklist(tableint internal_id, int level) const {
437
+ return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
438
+ }
403
439
 
404
- tableint next_closest_entry_point = selectedNeighbors.back();
405
440
 
406
- {
407
- linklistsizeint *ll_cur;
408
- if (level == 0)
409
- ll_cur = get_linklist0(cur_c);
410
- else
411
- ll_cur = get_linklist(cur_c, level);
441
+ linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
442
+ return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
443
+ }
412
444
 
413
- if (*ll_cur && !isUpdate) {
414
- throw std::runtime_error("The newly inserted element should have blank link list");
415
- }
416
- setListCount(ll_cur,selectedNeighbors.size());
417
- tableint *data = (tableint *) (ll_cur + 1);
418
- for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
419
- if (data[idx] && !isUpdate)
420
- throw std::runtime_error("Possible memory corruption");
421
- if (level > element_levels_[selectedNeighbors[idx]])
422
- throw std::runtime_error("Trying to make a link on a non-existent level");
423
445
 
424
- data[idx] = selectedNeighbors[idx];
446
+ tableint mutuallyConnectNewElement(
447
+ const void *data_point,
448
+ tableint cur_c,
449
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
450
+ int level,
451
+ bool isUpdate) {
452
+ size_t Mcurmax = level ? maxM_ : maxM0_;
453
+ getNeighborsByHeuristic2(top_candidates, M_);
454
+ if (top_candidates.size() > M_)
455
+ throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
425
456
 
426
- }
457
+ std::vector<tableint> selectedNeighbors;
458
+ selectedNeighbors.reserve(M_);
459
+ while (top_candidates.size() > 0) {
460
+ selectedNeighbors.push_back(top_candidates.top().second);
461
+ top_candidates.pop();
462
+ }
463
+
464
+ tableint next_closest_entry_point = selectedNeighbors.back();
465
+
466
+ {
467
+ // lock only during the update
468
+ // because during the addition the lock for cur_c is already acquired
469
+ std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
470
+ if (isUpdate) {
471
+ lock.lock();
427
472
  }
473
+ linklistsizeint *ll_cur;
474
+ if (level == 0)
475
+ ll_cur = get_linklist0(cur_c);
476
+ else
477
+ ll_cur = get_linklist(cur_c, level);
428
478
 
479
+ if (*ll_cur && !isUpdate) {
480
+ throw std::runtime_error("The newly inserted element should have blank link list");
481
+ }
482
+ setListCount(ll_cur, selectedNeighbors.size());
483
+ tableint *data = (tableint *) (ll_cur + 1);
429
484
  for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
485
+ if (data[idx] && !isUpdate)
486
+ throw std::runtime_error("Possible memory corruption");
487
+ if (level > element_levels_[selectedNeighbors[idx]])
488
+ throw std::runtime_error("Trying to make a link on a non-existent level");
430
489
 
431
- std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
490
+ data[idx] = selectedNeighbors[idx];
491
+ }
492
+ }
432
493
 
433
- linklistsizeint *ll_other;
434
- if (level == 0)
435
- ll_other = get_linklist0(selectedNeighbors[idx]);
436
- else
437
- ll_other = get_linklist(selectedNeighbors[idx], level);
494
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
495
+ std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
438
496
 
439
- size_t sz_link_list_other = getListCount(ll_other);
497
+ linklistsizeint *ll_other;
498
+ if (level == 0)
499
+ ll_other = get_linklist0(selectedNeighbors[idx]);
500
+ else
501
+ ll_other = get_linklist(selectedNeighbors[idx], level);
440
502
 
441
- if (sz_link_list_other > Mcurmax)
442
- throw std::runtime_error("Bad value of sz_link_list_other");
443
- if (selectedNeighbors[idx] == cur_c)
444
- throw std::runtime_error("Trying to connect an element to itself");
445
- if (level > element_levels_[selectedNeighbors[idx]])
446
- throw std::runtime_error("Trying to make a link on a non-existent level");
503
+ size_t sz_link_list_other = getListCount(ll_other);
447
504
 
448
- tableint *data = (tableint *) (ll_other + 1);
505
+ if (sz_link_list_other > Mcurmax)
506
+ throw std::runtime_error("Bad value of sz_link_list_other");
507
+ if (selectedNeighbors[idx] == cur_c)
508
+ throw std::runtime_error("Trying to connect an element to itself");
509
+ if (level > element_levels_[selectedNeighbors[idx]])
510
+ throw std::runtime_error("Trying to make a link on a non-existent level");
449
511
 
450
- bool is_cur_c_present = false;
451
- if (isUpdate) {
452
- for (size_t j = 0; j < sz_link_list_other; j++) {
453
- if (data[j] == cur_c) {
454
- is_cur_c_present = true;
455
- break;
456
- }
512
+ tableint *data = (tableint *) (ll_other + 1);
513
+
514
+ bool is_cur_c_present = false;
515
+ if (isUpdate) {
516
+ for (size_t j = 0; j < sz_link_list_other; j++) {
517
+ if (data[j] == cur_c) {
518
+ is_cur_c_present = true;
519
+ break;
457
520
  }
458
521
  }
522
+ }
459
523
 
460
- // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
461
- if (!is_cur_c_present) {
462
- if (sz_link_list_other < Mcurmax) {
463
- data[sz_link_list_other] = cur_c;
464
- setListCount(ll_other, sz_link_list_other + 1);
465
- } else {
466
- // finding the "weakest" element to replace it with the new one
467
- dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
468
- dist_func_param_);
469
- // Heuristic:
470
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
471
- candidates.emplace(d_max, cur_c);
472
-
473
- for (size_t j = 0; j < sz_link_list_other; j++) {
474
- candidates.emplace(
475
- fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
476
- dist_func_param_), data[j]);
477
- }
524
+ // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
525
+ if (!is_cur_c_present) {
526
+ if (sz_link_list_other < Mcurmax) {
527
+ data[sz_link_list_other] = cur_c;
528
+ setListCount(ll_other, sz_link_list_other + 1);
529
+ } else {
530
+ // finding the "weakest" element to replace it with the new one
531
+ dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
532
+ dist_func_param_);
533
+ // Heuristic:
534
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
535
+ candidates.emplace(d_max, cur_c);
478
536
 
479
- getNeighborsByHeuristic2(candidates, Mcurmax);
537
+ for (size_t j = 0; j < sz_link_list_other; j++) {
538
+ candidates.emplace(
539
+ fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
540
+ dist_func_param_), data[j]);
541
+ }
480
542
 
481
- int indx = 0;
482
- while (candidates.size() > 0) {
483
- data[indx] = candidates.top().second;
484
- candidates.pop();
485
- indx++;
486
- }
543
+ getNeighborsByHeuristic2(candidates, Mcurmax);
487
544
 
488
- setListCount(ll_other, indx);
489
- // Nearest K:
490
- /*int indx = -1;
491
- for (int j = 0; j < sz_link_list_other; j++) {
492
- dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
493
- if (d > d_max) {
494
- indx = j;
495
- d_max = d;
496
- }
545
+ int indx = 0;
546
+ while (candidates.size() > 0) {
547
+ data[indx] = candidates.top().second;
548
+ candidates.pop();
549
+ indx++;
550
+ }
551
+
552
+ setListCount(ll_other, indx);
553
+ // Nearest K:
554
+ /*int indx = -1;
555
+ for (int j = 0; j < sz_link_list_other; j++) {
556
+ dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
557
+ if (d > d_max) {
558
+ indx = j;
559
+ d_max = d;
497
560
  }
498
- if (indx >= 0) {
499
- data[indx] = cur_c;
500
- } */
501
561
  }
562
+ if (indx >= 0) {
563
+ data[indx] = cur_c;
564
+ } */
502
565
  }
503
566
  }
504
-
505
- return next_closest_entry_point;
506
567
  }
507
568
 
508
- std::mutex global;
509
- size_t ef_;
569
+ return next_closest_entry_point;
570
+ }
510
571
 
511
- void setEf(size_t ef) {
512
- ef_ = ef;
513
- }
514
572
 
573
+ void resizeIndex(size_t new_max_elements) {
574
+ if (new_max_elements < cur_element_count)
575
+ throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
515
576
 
516
- std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
517
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
518
- if (cur_element_count == 0) return top_candidates;
519
- tableint currObj = enterpoint_node_;
520
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
577
+ delete visited_list_pool_;
578
+ visited_list_pool_ = new VisitedListPool(1, new_max_elements);
521
579
 
522
- for (size_t level = maxlevel_; level > 0; level--) {
523
- bool changed = true;
524
- while (changed) {
525
- changed = false;
526
- int *data;
527
- data = (int *) get_linklist(currObj,level);
528
- int size = getListCount(data);
529
- tableint *datal = (tableint *) (data + 1);
530
- for (int i = 0; i < size; i++) {
531
- tableint cand = datal[i];
532
- if (cand < 0 || cand > max_elements_)
533
- throw std::runtime_error("cand error");
534
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
580
+ element_levels_.resize(new_max_elements);
535
581
 
536
- if (d < curdist) {
537
- curdist = d;
538
- currObj = cand;
539
- changed = true;
540
- }
541
- }
542
- }
543
- }
582
+ std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
544
583
 
545
- if (num_deleted_) {
546
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
547
- ef_);
548
- top_candidates.swap(top_candidates1);
549
- }
550
- else{
551
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
552
- ef_);
553
- top_candidates.swap(top_candidates1);
554
- }
555
-
556
- while (top_candidates.size() > k) {
557
- top_candidates.pop();
558
- }
559
- return top_candidates;
560
- };
584
+ // Reallocate base layer
585
+ char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
586
+ if (data_level0_memory_new == nullptr)
587
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
588
+ data_level0_memory_ = data_level0_memory_new;
561
589
 
562
- void resizeIndex(size_t new_max_elements){
563
- if (new_max_elements<cur_element_count)
564
- throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
590
+ // Reallocate all other layers
591
+ char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
592
+ if (linkLists_new == nullptr)
593
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
594
+ linkLists_ = linkLists_new;
565
595
 
596
+ max_elements_ = new_max_elements;
597
+ }
566
598
 
567
- delete visited_list_pool_;
568
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
569
599
 
600
+ void saveIndex(const std::string &location) {
601
+ std::ofstream output(location, std::ios::binary);
602
+ std::streampos position;
570
603
 
571
- element_levels_.resize(new_max_elements);
604
+ writeBinaryPOD(output, offsetLevel0_);
605
+ writeBinaryPOD(output, max_elements_);
606
+ writeBinaryPOD(output, cur_element_count);
607
+ writeBinaryPOD(output, size_data_per_element_);
608
+ writeBinaryPOD(output, label_offset_);
609
+ writeBinaryPOD(output, offsetData_);
610
+ writeBinaryPOD(output, maxlevel_);
611
+ writeBinaryPOD(output, enterpoint_node_);
612
+ writeBinaryPOD(output, maxM_);
572
613
 
573
- std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
614
+ writeBinaryPOD(output, maxM0_);
615
+ writeBinaryPOD(output, M_);
616
+ writeBinaryPOD(output, mult_);
617
+ writeBinaryPOD(output, ef_construction_);
574
618
 
575
- // Reallocate base layer
576
- char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
577
- if (data_level0_memory_new == nullptr)
578
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
579
- data_level0_memory_ = data_level0_memory_new;
619
+ output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
580
620
 
581
- // Reallocate all other layers
582
- char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
583
- if (linkLists_new == nullptr)
584
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
585
- linkLists_ = linkLists_new;
586
-
587
- max_elements_ = new_max_elements;
621
+ for (size_t i = 0; i < cur_element_count; i++) {
622
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
623
+ writeBinaryPOD(output, linkListSize);
624
+ if (linkListSize)
625
+ output.write(linkLists_[i], linkListSize);
588
626
  }
627
+ output.close();
628
+ }
629
+
630
+
631
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) {
632
+ std::ifstream input(location, std::ios::binary);
633
+
634
+ if (!input.is_open())
635
+ throw std::runtime_error("Cannot open file");
636
+
637
+ // get file size:
638
+ input.seekg(0, input.end);
639
+ std::streampos total_filesize = input.tellg();
640
+ input.seekg(0, input.beg);
641
+
642
+ readBinaryPOD(input, offsetLevel0_);
643
+ readBinaryPOD(input, max_elements_);
644
+ readBinaryPOD(input, cur_element_count);
645
+
646
+ size_t max_elements = max_elements_i;
647
+ if (max_elements < cur_element_count)
648
+ max_elements = max_elements_;
649
+ max_elements_ = max_elements;
650
+ readBinaryPOD(input, size_data_per_element_);
651
+ readBinaryPOD(input, label_offset_);
652
+ readBinaryPOD(input, offsetData_);
653
+ readBinaryPOD(input, maxlevel_);
654
+ readBinaryPOD(input, enterpoint_node_);
655
+
656
+ readBinaryPOD(input, maxM_);
657
+ readBinaryPOD(input, maxM0_);
658
+ readBinaryPOD(input, M_);
659
+ readBinaryPOD(input, mult_);
660
+ readBinaryPOD(input, ef_construction_);
661
+
662
+ data_size_ = s->get_data_size();
663
+ fstdistfunc_ = s->get_dist_func();
664
+ dist_func_param_ = s->get_dist_func_param();
665
+
666
+ auto pos = input.tellg();
667
+
668
+ /// Optional - check if index is ok:
669
+ input.seekg(cur_element_count * size_data_per_element_, input.cur);
670
+ for (size_t i = 0; i < cur_element_count; i++) {
671
+ if (input.tellg() < 0 || input.tellg() >= total_filesize) {
672
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
673
+ }
589
674
 
590
- void saveIndex(const std::string &location) {
591
- std::ofstream output(location, std::ios::binary);
592
- std::streampos position;
593
-
594
- writeBinaryPOD(output, offsetLevel0_);
595
- writeBinaryPOD(output, max_elements_);
596
- writeBinaryPOD(output, cur_element_count);
597
- writeBinaryPOD(output, size_data_per_element_);
598
- writeBinaryPOD(output, label_offset_);
599
- writeBinaryPOD(output, offsetData_);
600
- writeBinaryPOD(output, maxlevel_);
601
- writeBinaryPOD(output, enterpoint_node_);
602
- writeBinaryPOD(output, maxM_);
603
-
604
- writeBinaryPOD(output, maxM0_);
605
- writeBinaryPOD(output, M_);
606
- writeBinaryPOD(output, mult_);
607
- writeBinaryPOD(output, ef_construction_);
608
-
609
- output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
610
-
611
- for (size_t i = 0; i < cur_element_count; i++) {
612
- unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
613
- writeBinaryPOD(output, linkListSize);
614
- if (linkListSize)
615
- output.write(linkLists_[i], linkListSize);
675
+ unsigned int linkListSize;
676
+ readBinaryPOD(input, linkListSize);
677
+ if (linkListSize != 0) {
678
+ input.seekg(linkListSize, input.cur);
616
679
  }
617
- output.close();
618
680
  }
619
681
 
620
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
621
- std::ifstream input(location, std::ios::binary);
622
-
623
- if (!input.is_open())
624
- throw std::runtime_error("Cannot open file");
625
-
626
- // get file size:
627
- input.seekg(0,input.end);
628
- std::streampos total_filesize=input.tellg();
629
- input.seekg(0,input.beg);
630
-
631
- readBinaryPOD(input, offsetLevel0_);
632
- readBinaryPOD(input, max_elements_);
633
- readBinaryPOD(input, cur_element_count);
682
+ // throw exception if it either corrupted or old index
683
+ if (input.tellg() != total_filesize)
684
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
685
+
686
+ input.clear();
687
+ /// Optional check end
688
+
689
+ input.seekg(pos, input.beg);
690
+
691
+ data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
692
+ if (data_level0_memory_ == nullptr)
693
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
694
+ input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
695
+
696
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
697
+
698
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
699
+ std::vector<std::mutex>(max_elements).swap(link_list_locks_);
700
+ std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
701
+
702
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
703
+
704
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
705
+ if (linkLists_ == nullptr)
706
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
707
+ element_levels_ = std::vector<int>(max_elements);
708
+ revSize_ = 1.0 / mult_;
709
+ ef_ = 10;
710
+ for (size_t i = 0; i < cur_element_count; i++) {
711
+ label_lookup_[getExternalLabel(i)] = i;
712
+ unsigned int linkListSize;
713
+ readBinaryPOD(input, linkListSize);
714
+ if (linkListSize == 0) {
715
+ element_levels_[i] = 0;
716
+ linkLists_[i] = nullptr;
717
+ } else {
718
+ element_levels_[i] = linkListSize / size_links_per_element_;
719
+ linkLists_[i] = (char *) malloc(linkListSize);
720
+ if (linkLists_[i] == nullptr)
721
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
722
+ input.read(linkLists_[i], linkListSize);
723
+ }
724
+ }
634
725
 
635
- size_t max_elements = max_elements_i;
636
- if(max_elements < cur_element_count)
637
- max_elements = max_elements_;
638
- max_elements_ = max_elements;
639
- readBinaryPOD(input, size_data_per_element_);
640
- readBinaryPOD(input, label_offset_);
641
- readBinaryPOD(input, offsetData_);
642
- readBinaryPOD(input, maxlevel_);
643
- readBinaryPOD(input, enterpoint_node_);
726
+ for (size_t i = 0; i < cur_element_count; i++) {
727
+ if (isMarkedDeleted(i)) {
728
+ num_deleted_ += 1;
729
+ if (allow_replace_deleted_) deleted_elements.insert(i);
730
+ }
731
+ }
644
732
 
645
- readBinaryPOD(input, maxM_);
646
- readBinaryPOD(input, maxM0_);
647
- readBinaryPOD(input, M_);
648
- readBinaryPOD(input, mult_);
649
- readBinaryPOD(input, ef_construction_);
733
+ input.close();
650
734
 
735
+ return;
736
+ }
651
737
 
652
- data_size_ = s->get_data_size();
653
- fstdistfunc_ = s->get_dist_func();
654
- dist_func_param_ = s->get_dist_func_param();
655
738
 
656
- auto pos=input.tellg();
739
+ template<typename data_t>
740
+ std::vector<data_t> getDataByLabel(labeltype label) const {
741
+ // lock all operations with element by label
742
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
657
743
 
744
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
745
+ auto search = label_lookup_.find(label);
746
+ if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
747
+ throw std::runtime_error("Label not found");
748
+ }
749
+ tableint internalId = search->second;
750
+ lock_table.unlock();
751
+
752
+ char* data_ptrv = getDataByInternalId(internalId);
753
+ size_t dim = *((size_t *) dist_func_param_);
754
+ std::vector<data_t> data;
755
+ data_t* data_ptr = (data_t*) data_ptrv;
756
+ for (int i = 0; i < dim; i++) {
757
+ data.push_back(*data_ptr);
758
+ data_ptr += 1;
759
+ }
760
+ return data;
761
+ }
658
762
 
659
- /// Optional - check if index is ok:
660
763
 
661
- input.seekg(cur_element_count * size_data_per_element_,input.cur);
662
- for (size_t i = 0; i < cur_element_count; i++) {
663
- if(input.tellg() < 0 || input.tellg()>=total_filesize){
664
- throw std::runtime_error("Index seems to be corrupted or unsupported");
665
- }
764
+ /*
765
+ * Marks an element with the given label deleted, does NOT really change the current graph.
766
+ */
767
+ void markDelete(labeltype label) {
768
+ // lock all operations with element by label
769
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
666
770
 
667
- unsigned int linkListSize;
668
- readBinaryPOD(input, linkListSize);
669
- if (linkListSize != 0) {
670
- input.seekg(linkListSize,input.cur);
671
- }
771
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
772
+ auto search = label_lookup_.find(label);
773
+ if (search == label_lookup_.end()) {
774
+ throw std::runtime_error("Label not found");
775
+ }
776
+ tableint internalId = search->second;
777
+ lock_table.unlock();
778
+
779
+ markDeletedInternal(internalId);
780
+ }
781
+
782
+
783
+ /*
784
+ * Uses the last 16 bits of the memory for the linked list size to store the mark,
785
+ * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases.
786
+ */
787
+ void markDeletedInternal(tableint internalId) {
788
+ assert(internalId < cur_element_count);
789
+ if (!isMarkedDeleted(internalId)) {
790
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
791
+ *ll_cur |= DELETE_MARK;
792
+ num_deleted_ += 1;
793
+ if (allow_replace_deleted_) {
794
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
795
+ deleted_elements.insert(internalId);
672
796
  }
797
+ } else {
798
+ throw std::runtime_error("The requested to delete element is already deleted");
799
+ }
800
+ }
801
+
802
+
803
+ /*
804
+ * Removes the deleted mark of the node, does NOT really change the current graph.
805
+ *
806
+ * Note: the method is not safe to use when replacement of deleted elements is enabled,
807
+ * because elements marked as deleted can be completely removed by addPoint
808
+ */
809
+ void unmarkDelete(labeltype label) {
810
+ // lock all operations with element by label
811
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
812
+
813
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
814
+ auto search = label_lookup_.find(label);
815
+ if (search == label_lookup_.end()) {
816
+ throw std::runtime_error("Label not found");
817
+ }
818
+ tableint internalId = search->second;
819
+ lock_table.unlock();
820
+
821
+ unmarkDeletedInternal(internalId);
822
+ }
823
+
824
+
825
+
826
+ /*
827
+ * Remove the deleted mark of the node.
828
+ */
829
+ void unmarkDeletedInternal(tableint internalId) {
830
+ assert(internalId < cur_element_count);
831
+ if (isMarkedDeleted(internalId)) {
832
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;
833
+ *ll_cur &= ~DELETE_MARK;
834
+ num_deleted_ -= 1;
835
+ if (allow_replace_deleted_) {
836
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
837
+ deleted_elements.erase(internalId);
838
+ }
839
+ } else {
840
+ throw std::runtime_error("The requested to undelete element is not deleted");
841
+ }
842
+ }
673
843
 
674
- // throw exception if it either corrupted or old index
675
- if(input.tellg()!=total_filesize)
676
- throw std::runtime_error("Index seems to be corrupted or unsupported");
677
-
678
- input.clear();
679
-
680
- /// Optional check end
681
-
682
- input.seekg(pos,input.beg);
683
-
684
- data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
685
- if (data_level0_memory_ == nullptr)
686
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
687
- input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
688
844
 
689
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
845
+ /*
846
+ * Checks the first 16 bits of the memory to see if the element is marked deleted.
847
+ */
848
+ bool isMarkedDeleted(tableint internalId) const {
849
+ unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2;
850
+ return *ll_cur & DELETE_MARK;
851
+ }
690
852
 
691
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
692
- std::vector<std::mutex>(max_elements).swap(link_list_locks_);
693
- std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
694
853
 
695
- visited_list_pool_ = new VisitedListPool(1, max_elements);
854
+ unsigned short int getListCount(linklistsizeint * ptr) const {
855
+ return *((unsigned short int *)ptr);
856
+ }
696
857
 
697
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
698
- if (linkLists_ == nullptr)
699
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
700
- element_levels_ = std::vector<int>(max_elements);
701
- revSize_ = 1.0 / mult_;
702
- ef_ = 10;
703
- for (size_t i = 0; i < cur_element_count; i++) {
704
- label_lookup_[getExternalLabel(i)]=i;
705
- unsigned int linkListSize;
706
- readBinaryPOD(input, linkListSize);
707
- if (linkListSize == 0) {
708
- element_levels_[i] = 0;
709
858
 
710
- linkLists_[i] = nullptr;
711
- } else {
712
- element_levels_[i] = linkListSize / size_links_per_element_;
713
- linkLists_[i] = (char *) malloc(linkListSize);
714
- if (linkLists_[i] == nullptr)
715
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
716
- input.read(linkLists_[i], linkListSize);
717
- }
718
- }
719
-
720
- for (size_t i = 0; i < cur_element_count; i++) {
721
- if(isMarkedDeleted(i))
722
- num_deleted_ += 1;
723
- }
859
+ void setListCount(linklistsizeint * ptr, unsigned short int size) const {
860
+ *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
861
+ }
724
862
 
725
- input.close();
726
863
 
727
- return;
864
+ /*
865
+ * Adds point. Updates the point if it is already in the index.
866
+ * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
867
+ */
868
+ void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
869
+ if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
870
+ throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
728
871
  }
729
872
 
730
- template<typename data_t>
731
- std::vector<data_t> getDataByLabel(labeltype label) const
732
- {
733
- tableint label_c;
734
- auto search = label_lookup_.find(label);
735
- if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
736
- throw std::runtime_error("Label not found");
737
- }
738
- label_c = search->second;
739
-
740
- char* data_ptrv = getDataByInternalId(label_c);
741
- size_t dim = *((size_t *) dist_func_param_);
742
- std::vector<data_t> data;
743
- data_t* data_ptr = (data_t*) data_ptrv;
744
- for (int i = 0; i < dim; i++) {
745
- data.push_back(*data_ptr);
746
- data_ptr += 1;
747
- }
748
- return data;
873
+ // lock all operations with element by label
874
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
875
+ if (!replace_deleted) {
876
+ addPoint(data_point, label, -1);
877
+ return;
749
878
  }
750
-
751
- static const unsigned char DELETE_MARK = 0x01;
752
- // static const unsigned char REUSE_MARK = 0x10;
753
- /**
754
- * Marks an element with the given label deleted, does NOT really change the current graph.
755
- * @param label
756
- */
757
- void markDelete(labeltype label)
758
- {
759
- auto search = label_lookup_.find(label);
760
- if (search == label_lookup_.end()) {
761
- throw std::runtime_error("Label not found");
762
- }
763
- tableint internalId = search->second;
764
- markDeletedInternal(internalId);
879
+ // check if there is vacant place
880
+ tableint internal_id_replaced;
881
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
882
+ bool is_vacant_place = !deleted_elements.empty();
883
+ if (is_vacant_place) {
884
+ internal_id_replaced = *deleted_elements.begin();
885
+ deleted_elements.erase(internal_id_replaced);
765
886
  }
766
-
767
- /**
768
- * Uses the first 8 bits of the memory for the linked list to store the mark,
769
- * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
770
- * @param internalId
771
- */
772
- void markDeletedInternal(tableint internalId) {
773
- assert(internalId < cur_element_count);
774
- if (!isMarkedDeleted(internalId))
775
- {
776
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
777
- *ll_cur |= DELETE_MARK;
778
- num_deleted_ += 1;
779
- }
780
- else
781
- {
782
- throw std::runtime_error("The requested to delete element is already deleted");
783
- }
887
+ lock_deleted_elements.unlock();
888
+
889
+ // if there is no vacant place then add or update point
890
+ // else add point to vacant place
891
+ if (!is_vacant_place) {
892
+ addPoint(data_point, label, -1);
893
+ } else {
894
+ // we assume that there are no concurrent operations on deleted element
895
+ labeltype label_replaced = getExternalLabel(internal_id_replaced);
896
+ setExternalLabel(internal_id_replaced, label);
897
+
898
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
899
+ label_lookup_.erase(label_replaced);
900
+ label_lookup_[label] = internal_id_replaced;
901
+ lock_table.unlock();
902
+
903
+ unmarkDeletedInternal(internal_id_replaced);
904
+ updatePoint(data_point, internal_id_replaced, 1.0);
784
905
  }
906
+ }
785
907
 
786
- /**
787
- * Remove the deleted mark of the node, does NOT really change the current graph.
788
- * @param label
789
- */
790
- void unmarkDelete(labeltype label)
791
- {
792
- auto search = label_lookup_.find(label);
793
- if (search == label_lookup_.end()) {
794
- throw std::runtime_error("Label not found");
795
- }
796
- tableint internalId = search->second;
797
- unmarkDeletedInternal(internalId);
798
- }
799
908
 
800
- /**
801
- * Remove the deleted mark of the node.
802
- * @param internalId
803
- */
804
- void unmarkDeletedInternal(tableint internalId) {
805
- assert(internalId < cur_element_count);
806
- if (isMarkedDeleted(internalId))
807
- {
808
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
809
- *ll_cur &= ~DELETE_MARK;
810
- num_deleted_ -= 1;
811
- }
812
- else
813
- {
814
- throw std::runtime_error("The requested to undelete element is not deleted");
815
- }
816
- }
909
+ void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
910
+ // update the feature vector associated with existing point with new vector
911
+ memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
817
912
 
818
- /**
819
- * Checks the first 8 bits of the memory to see if the element is marked deleted.
820
- * @param internalId
821
- * @return
822
- */
823
- bool isMarkedDeleted(tableint internalId) const {
824
- unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
825
- return *ll_cur & DELETE_MARK;
826
- }
913
+ int maxLevelCopy = maxlevel_;
914
+ tableint entryPointCopy = enterpoint_node_;
915
+ // If point to be updated is entry point and graph just contains single element then just return.
916
+ if (entryPointCopy == internalId && cur_element_count == 1)
917
+ return;
827
918
 
828
- unsigned short int getListCount(linklistsizeint * ptr) const {
829
- return *((unsigned short int *)ptr);
830
- }
919
+ int elemLevel = element_levels_[internalId];
920
+ std::uniform_real_distribution<float> distribution(0.0, 1.0);
921
+ for (int layer = 0; layer <= elemLevel; layer++) {
922
+ std::unordered_set<tableint> sCand;
923
+ std::unordered_set<tableint> sNeigh;
924
+ std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
925
+ if (listOneHop.size() == 0)
926
+ continue;
831
927
 
832
- void setListCount(linklistsizeint * ptr, unsigned short int size) const {
833
- *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
834
- }
928
+ sCand.insert(internalId);
835
929
 
836
- void addPoint(const void *data_point, labeltype label) {
837
- addPoint(data_point, label,-1);
838
- }
930
+ for (auto&& elOneHop : listOneHop) {
931
+ sCand.insert(elOneHop);
839
932
 
840
- void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
841
- // update the feature vector associated with existing point with new vector
842
- memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
843
-
844
- int maxLevelCopy = maxlevel_;
845
- tableint entryPointCopy = enterpoint_node_;
846
- // If point to be updated is entry point and graph just contains single element then just return.
847
- if (entryPointCopy == internalId && cur_element_count == 1)
848
- return;
849
-
850
- int elemLevel = element_levels_[internalId];
851
- std::uniform_real_distribution<float> distribution(0.0, 1.0);
852
- for (int layer = 0; layer <= elemLevel; layer++) {
853
- std::unordered_set<tableint> sCand;
854
- std::unordered_set<tableint> sNeigh;
855
- std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
856
- if (listOneHop.size() == 0)
933
+ if (distribution(update_probability_generator_) > updateNeighborProbability)
857
934
  continue;
858
935
 
859
- sCand.insert(internalId);
936
+ sNeigh.insert(elOneHop);
860
937
 
861
- for (auto&& elOneHop : listOneHop) {
862
- sCand.insert(elOneHop);
863
-
864
- if (distribution(update_probability_generator_) > updateNeighborProbability)
865
- continue;
866
-
867
- sNeigh.insert(elOneHop);
868
-
869
- std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
870
- for (auto&& elTwoHop : listTwoHop) {
871
- sCand.insert(elTwoHop);
872
- }
938
+ std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
939
+ for (auto&& elTwoHop : listTwoHop) {
940
+ sCand.insert(elTwoHop);
873
941
  }
942
+ }
874
943
 
875
- for (auto&& neigh : sNeigh) {
876
- // if (neigh == internalId)
877
- // continue;
944
+ for (auto&& neigh : sNeigh) {
945
+ // if (neigh == internalId)
946
+ // continue;
878
947
 
879
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
880
- size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
881
- size_t elementsToKeep = std::min(ef_construction_, size);
882
- for (auto&& cand : sCand) {
883
- if (cand == neigh)
884
- continue;
885
-
886
- dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
887
- if (candidates.size() < elementsToKeep) {
948
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
949
+ size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
950
+ size_t elementsToKeep = std::min(ef_construction_, size);
951
+ for (auto&& cand : sCand) {
952
+ if (cand == neigh)
953
+ continue;
954
+
955
+ dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
956
+ if (candidates.size() < elementsToKeep) {
957
+ candidates.emplace(distance, cand);
958
+ } else {
959
+ if (distance < candidates.top().first) {
960
+ candidates.pop();
888
961
  candidates.emplace(distance, cand);
889
- } else {
890
- if (distance < candidates.top().first) {
891
- candidates.pop();
892
- candidates.emplace(distance, cand);
893
- }
894
962
  }
895
963
  }
964
+ }
896
965
 
897
- // Retrieve neighbours using heuristic and set connections.
898
- getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
899
-
900
- {
901
- std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
902
- linklistsizeint *ll_cur;
903
- ll_cur = get_linklist_at_level(neigh, layer);
904
- size_t candSize = candidates.size();
905
- setListCount(ll_cur, candSize);
906
- tableint *data = (tableint *) (ll_cur + 1);
907
- for (size_t idx = 0; idx < candSize; idx++) {
908
- data[idx] = candidates.top().second;
909
- candidates.pop();
910
- }
966
+ // Retrieve neighbours using heuristic and set connections.
967
+ getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
968
+
969
+ {
970
+ std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
971
+ linklistsizeint *ll_cur;
972
+ ll_cur = get_linklist_at_level(neigh, layer);
973
+ size_t candSize = candidates.size();
974
+ setListCount(ll_cur, candSize);
975
+ tableint *data = (tableint *) (ll_cur + 1);
976
+ for (size_t idx = 0; idx < candSize; idx++) {
977
+ data[idx] = candidates.top().second;
978
+ candidates.pop();
911
979
  }
912
980
  }
913
981
  }
982
+ }
914
983
 
915
- repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
916
- };
984
+ repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
985
+ }
917
986
 
918
- void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
919
- tableint currObj = entryPointInternalId;
920
- if (dataPointLevel < maxLevel) {
921
- dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
922
- for (int level = maxLevel; level > dataPointLevel; level--) {
923
- bool changed = true;
924
- while (changed) {
925
- changed = false;
926
- unsigned int *data;
927
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
928
- data = get_linklist_at_level(currObj,level);
929
- int size = getListCount(data);
930
- tableint *datal = (tableint *) (data + 1);
987
+
988
+ void repairConnectionsForUpdate(
989
+ const void *dataPoint,
990
+ tableint entryPointInternalId,
991
+ tableint dataPointInternalId,
992
+ int dataPointLevel,
993
+ int maxLevel) {
994
+ tableint currObj = entryPointInternalId;
995
+ if (dataPointLevel < maxLevel) {
996
+ dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
997
+ for (int level = maxLevel; level > dataPointLevel; level--) {
998
+ bool changed = true;
999
+ while (changed) {
1000
+ changed = false;
1001
+ unsigned int *data;
1002
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1003
+ data = get_linklist_at_level(currObj, level);
1004
+ int size = getListCount(data);
1005
+ tableint *datal = (tableint *) (data + 1);
931
1006
  #ifdef USE_SSE
932
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
1007
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
933
1008
  #endif
934
- for (int i = 0; i < size; i++) {
1009
+ for (int i = 0; i < size; i++) {
935
1010
  #ifdef USE_SSE
936
- _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
1011
+ _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
937
1012
  #endif
938
- tableint cand = datal[i];
939
- dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
940
- if (d < curdist) {
941
- curdist = d;
942
- currObj = cand;
943
- changed = true;
944
- }
1013
+ tableint cand = datal[i];
1014
+ dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
1015
+ if (d < curdist) {
1016
+ curdist = d;
1017
+ currObj = cand;
1018
+ changed = true;
945
1019
  }
946
1020
  }
947
1021
  }
948
1022
  }
1023
+ }
949
1024
 
950
- if (dataPointLevel > maxLevel)
951
- throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
952
-
953
- for (int level = dataPointLevel; level >= 0; level--) {
954
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
955
- currObj, dataPoint, level);
1025
+ if (dataPointLevel > maxLevel)
1026
+ throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
956
1027
 
957
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
958
- while (topCandidates.size() > 0) {
959
- if (topCandidates.top().second != dataPointInternalId)
960
- filteredTopCandidates.push(topCandidates.top());
1028
+ for (int level = dataPointLevel; level >= 0; level--) {
1029
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
1030
+ currObj, dataPoint, level);
961
1031
 
962
- topCandidates.pop();
963
- }
1032
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
1033
+ while (topCandidates.size() > 0) {
1034
+ if (topCandidates.top().second != dataPointInternalId)
1035
+ filteredTopCandidates.push(topCandidates.top());
964
1036
 
965
- // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
966
- // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
967
- if (filteredTopCandidates.size() > 0) {
968
- bool epDeleted = isMarkedDeleted(entryPointInternalId);
969
- if (epDeleted) {
970
- filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
971
- if (filteredTopCandidates.size() > ef_construction_)
972
- filteredTopCandidates.pop();
973
- }
1037
+ topCandidates.pop();
1038
+ }
974
1039
 
975
- currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
1040
+ // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
1041
+ // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
1042
+ if (filteredTopCandidates.size() > 0) {
1043
+ bool epDeleted = isMarkedDeleted(entryPointInternalId);
1044
+ if (epDeleted) {
1045
+ filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
1046
+ if (filteredTopCandidates.size() > ef_construction_)
1047
+ filteredTopCandidates.pop();
976
1048
  }
1049
+
1050
+ currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
977
1051
  }
978
1052
  }
1053
+ }
1054
+
979
1055
 
980
- std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
981
- std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
982
- unsigned int *data = get_linklist_at_level(internalId, level);
983
- int size = getListCount(data);
984
- std::vector<tableint> result(size);
985
- tableint *ll = (tableint *) (data + 1);
986
- memcpy(result.data(), ll,size * sizeof(tableint));
987
- return result;
988
- };
989
-
990
- tableint addPoint(const void *data_point, labeltype label, int level) {
991
-
992
- tableint cur_c = 0;
993
- {
994
- // Checking if the element with the same label already exists
995
- // if so, updating it *instead* of creating a new element.
996
- std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
997
- auto search = label_lookup_.find(label);
998
- if (search != label_lookup_.end()) {
999
- tableint existingInternalId = search->second;
1000
- templock_curr.unlock();
1001
-
1002
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
1056
+ std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
1057
+ std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
1058
+ unsigned int *data = get_linklist_at_level(internalId, level);
1059
+ int size = getListCount(data);
1060
+ std::vector<tableint> result(size);
1061
+ tableint *ll = (tableint *) (data + 1);
1062
+ memcpy(result.data(), ll, size * sizeof(tableint));
1063
+ return result;
1064
+ }
1003
1065
 
1066
+
1067
+ tableint addPoint(const void *data_point, labeltype label, int level) {
1068
+ tableint cur_c = 0;
1069
+ {
1070
+ // Checking if the element with the same label already exists
1071
+ // if so, updating it *instead* of creating a new element.
1072
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
1073
+ auto search = label_lookup_.find(label);
1074
+ if (search != label_lookup_.end()) {
1075
+ tableint existingInternalId = search->second;
1076
+ if (allow_replace_deleted_) {
1004
1077
  if (isMarkedDeleted(existingInternalId)) {
1005
- unmarkDeletedInternal(existingInternalId);
1078
+ throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
1006
1079
  }
1007
- updatePoint(data_point, existingInternalId, 1.0);
1008
-
1009
- return existingInternalId;
1010
1080
  }
1081
+ lock_table.unlock();
1011
1082
 
1012
- if (cur_element_count >= max_elements_) {
1013
- throw std::runtime_error("The number of elements exceeds the specified limit");
1014
- };
1083
+ if (isMarkedDeleted(existingInternalId)) {
1084
+ unmarkDeletedInternal(existingInternalId);
1085
+ }
1086
+ updatePoint(data_point, existingInternalId, 1.0);
1015
1087
 
1016
- cur_c = cur_element_count;
1017
- cur_element_count++;
1018
- label_lookup_[label] = cur_c;
1088
+ return existingInternalId;
1019
1089
  }
1020
1090
 
1021
- // Take update lock to prevent race conditions on an element with insertion/update at the same time.
1022
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
1023
- std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1024
- int curlevel = getRandomLevel(mult_);
1025
- if (level > 0)
1026
- curlevel = level;
1091
+ if (cur_element_count >= max_elements_) {
1092
+ throw std::runtime_error("The number of elements exceeds the specified limit");
1093
+ }
1027
1094
 
1028
- element_levels_[cur_c] = curlevel;
1095
+ cur_c = cur_element_count;
1096
+ cur_element_count++;
1097
+ label_lookup_[label] = cur_c;
1098
+ }
1029
1099
 
1100
+ std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1101
+ int curlevel = getRandomLevel(mult_);
1102
+ if (level > 0)
1103
+ curlevel = level;
1030
1104
 
1031
- std::unique_lock <std::mutex> templock(global);
1032
- int maxlevelcopy = maxlevel_;
1033
- if (curlevel <= maxlevelcopy)
1034
- templock.unlock();
1035
- tableint currObj = enterpoint_node_;
1036
- tableint enterpoint_copy = enterpoint_node_;
1105
+ element_levels_[cur_c] = curlevel;
1037
1106
 
1107
+ std::unique_lock <std::mutex> templock(global);
1108
+ int maxlevelcopy = maxlevel_;
1109
+ if (curlevel <= maxlevelcopy)
1110
+ templock.unlock();
1111
+ tableint currObj = enterpoint_node_;
1112
+ tableint enterpoint_copy = enterpoint_node_;
1038
1113
 
1039
- memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1114
+ memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1040
1115
 
1041
- // Initialisation of the data and label
1042
- memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1043
- memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1116
+ // Initialisation of the data and label
1117
+ memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1118
+ memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1044
1119
 
1120
+ if (curlevel) {
1121
+ linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1122
+ if (linkLists_[cur_c] == nullptr)
1123
+ throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1124
+ memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1125
+ }
1045
1126
 
1046
- if (curlevel) {
1047
- linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1048
- if (linkLists_[cur_c] == nullptr)
1049
- throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1050
- memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1051
- }
1127
+ if ((signed)currObj != -1) {
1128
+ if (curlevel < maxlevelcopy) {
1129
+ dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1130
+ for (int level = maxlevelcopy; level > curlevel; level--) {
1131
+ bool changed = true;
1132
+ while (changed) {
1133
+ changed = false;
1134
+ unsigned int *data;
1135
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1136
+ data = get_linklist(currObj, level);
1137
+ int size = getListCount(data);
1052
1138
 
1053
- if ((signed)currObj != -1) {
1054
-
1055
- if (curlevel < maxlevelcopy) {
1056
-
1057
- dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1058
- for (int level = maxlevelcopy; level > curlevel; level--) {
1059
-
1060
-
1061
- bool changed = true;
1062
- while (changed) {
1063
- changed = false;
1064
- unsigned int *data;
1065
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1066
- data = get_linklist(currObj,level);
1067
- int size = getListCount(data);
1068
-
1069
- tableint *datal = (tableint *) (data + 1);
1070
- for (int i = 0; i < size; i++) {
1071
- tableint cand = datal[i];
1072
- if (cand < 0 || cand > max_elements_)
1073
- throw std::runtime_error("cand error");
1074
- dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1075
- if (d < curdist) {
1076
- curdist = d;
1077
- currObj = cand;
1078
- changed = true;
1079
- }
1139
+ tableint *datal = (tableint *) (data + 1);
1140
+ for (int i = 0; i < size; i++) {
1141
+ tableint cand = datal[i];
1142
+ if (cand < 0 || cand > max_elements_)
1143
+ throw std::runtime_error("cand error");
1144
+ dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1145
+ if (d < curdist) {
1146
+ curdist = d;
1147
+ currObj = cand;
1148
+ changed = true;
1080
1149
  }
1081
1150
  }
1082
1151
  }
1083
1152
  }
1153
+ }
1084
1154
 
1085
- bool epDeleted = isMarkedDeleted(enterpoint_copy);
1086
- for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1087
- if (level > maxlevelcopy || level < 0) // possible?
1088
- throw std::runtime_error("Level error");
1089
-
1090
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1091
- currObj, data_point, level);
1092
- if (epDeleted) {
1093
- top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1094
- if (top_candidates.size() > ef_construction_)
1095
- top_candidates.pop();
1096
- }
1097
- currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1155
+ bool epDeleted = isMarkedDeleted(enterpoint_copy);
1156
+ for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1157
+ if (level > maxlevelcopy || level < 0) // possible?
1158
+ throw std::runtime_error("Level error");
1159
+
1160
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1161
+ currObj, data_point, level);
1162
+ if (epDeleted) {
1163
+ top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1164
+ if (top_candidates.size() > ef_construction_)
1165
+ top_candidates.pop();
1098
1166
  }
1099
-
1100
-
1101
- } else {
1102
- // Do nothing for the first element
1103
- enterpoint_node_ = 0;
1104
- maxlevel_ = curlevel;
1105
-
1167
+ currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1106
1168
  }
1169
+ } else {
1170
+ // Do nothing for the first element
1171
+ enterpoint_node_ = 0;
1172
+ maxlevel_ = curlevel;
1173
+ }
1107
1174
 
1108
- //Releasing lock for the maximum level
1109
- if (curlevel > maxlevelcopy) {
1110
- enterpoint_node_ = cur_c;
1111
- maxlevel_ = curlevel;
1112
- }
1113
- return cur_c;
1114
- };
1175
+ // Releasing lock for the maximum level
1176
+ if (curlevel > maxlevelcopy) {
1177
+ enterpoint_node_ = cur_c;
1178
+ maxlevel_ = curlevel;
1179
+ }
1180
+ return cur_c;
1181
+ }
1115
1182
 
1116
- std::priority_queue<std::pair<dist_t, labeltype >>
1117
- searchKnn(const void *query_data, size_t k) const {
1118
- std::priority_queue<std::pair<dist_t, labeltype >> result;
1119
- if (cur_element_count == 0) return result;
1120
1183
 
1121
- tableint currObj = enterpoint_node_;
1122
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1184
+ std::priority_queue<std::pair<dist_t, labeltype >>
1185
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
1186
+ std::priority_queue<std::pair<dist_t, labeltype >> result;
1187
+ if (cur_element_count == 0) return result;
1123
1188
 
1124
- for (int level = maxlevel_; level > 0; level--) {
1125
- bool changed = true;
1126
- while (changed) {
1127
- changed = false;
1128
- unsigned int *data;
1189
+ tableint currObj = enterpoint_node_;
1190
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1129
1191
 
1130
- data = (unsigned int *) get_linklist(currObj, level);
1131
- int size = getListCount(data);
1132
- metric_hops++;
1133
- metric_distance_computations+=size;
1192
+ for (int level = maxlevel_; level > 0; level--) {
1193
+ bool changed = true;
1194
+ while (changed) {
1195
+ changed = false;
1196
+ unsigned int *data;
1134
1197
 
1135
- tableint *datal = (tableint *) (data + 1);
1136
- for (int i = 0; i < size; i++) {
1137
- tableint cand = datal[i];
1138
- if (cand < 0 || cand > max_elements_)
1139
- throw std::runtime_error("cand error");
1140
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1198
+ data = (unsigned int *) get_linklist(currObj, level);
1199
+ int size = getListCount(data);
1200
+ metric_hops++;
1201
+ metric_distance_computations+=size;
1141
1202
 
1142
- if (d < curdist) {
1143
- curdist = d;
1144
- currObj = cand;
1145
- changed = true;
1146
- }
1203
+ tableint *datal = (tableint *) (data + 1);
1204
+ for (int i = 0; i < size; i++) {
1205
+ tableint cand = datal[i];
1206
+ if (cand < 0 || cand > max_elements_)
1207
+ throw std::runtime_error("cand error");
1208
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1209
+
1210
+ if (d < curdist) {
1211
+ curdist = d;
1212
+ currObj = cand;
1213
+ changed = true;
1147
1214
  }
1148
1215
  }
1149
1216
  }
1217
+ }
1150
1218
 
1151
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1152
- if (num_deleted_) {
1153
- top_candidates=searchBaseLayerST<true,true>(
1154
- currObj, query_data, std::max(ef_, k));
1155
- }
1156
- else{
1157
- top_candidates=searchBaseLayerST<false,true>(
1158
- currObj, query_data, std::max(ef_, k));
1159
- }
1219
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1220
+ if (num_deleted_) {
1221
+ top_candidates = searchBaseLayerST<true, true>(
1222
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1223
+ } else {
1224
+ top_candidates = searchBaseLayerST<false, true>(
1225
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1226
+ }
1160
1227
 
1161
- while (top_candidates.size() > k) {
1162
- top_candidates.pop();
1163
- }
1164
- while (top_candidates.size() > 0) {
1165
- std::pair<dist_t, tableint> rez = top_candidates.top();
1166
- result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1167
- top_candidates.pop();
1168
- }
1169
- return result;
1170
- };
1171
-
1172
- void checkIntegrity(){
1173
- int connections_checked=0;
1174
- std::vector <int > inbound_connections_num(cur_element_count,0);
1175
- for(int i = 0;i < cur_element_count; i++){
1176
- for(int l = 0;l <= element_levels_[i]; l++){
1177
- linklistsizeint *ll_cur = get_linklist_at_level(i,l);
1178
- int size = getListCount(ll_cur);
1179
- tableint *data = (tableint *) (ll_cur + 1);
1180
- std::unordered_set<tableint> s;
1181
- for (int j=0; j<size; j++){
1182
- assert(data[j] > 0);
1183
- assert(data[j] < cur_element_count);
1184
- assert (data[j] != i);
1185
- inbound_connections_num[data[j]]++;
1186
- s.insert(data[j]);
1187
- connections_checked++;
1228
+ while (top_candidates.size() > k) {
1229
+ top_candidates.pop();
1230
+ }
1231
+ while (top_candidates.size() > 0) {
1232
+ std::pair<dist_t, tableint> rez = top_candidates.top();
1233
+ result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1234
+ top_candidates.pop();
1235
+ }
1236
+ return result;
1237
+ }
1188
1238
 
1189
- }
1190
- assert(s.size() == size);
1239
+
1240
+ void checkIntegrity() {
1241
+ int connections_checked = 0;
1242
+ std::vector <int > inbound_connections_num(cur_element_count, 0);
1243
+ for (int i = 0; i < cur_element_count; i++) {
1244
+ for (int l = 0; l <= element_levels_[i]; l++) {
1245
+ linklistsizeint *ll_cur = get_linklist_at_level(i, l);
1246
+ int size = getListCount(ll_cur);
1247
+ tableint *data = (tableint *) (ll_cur + 1);
1248
+ std::unordered_set<tableint> s;
1249
+ for (int j = 0; j < size; j++) {
1250
+ assert(data[j] > 0);
1251
+ assert(data[j] < cur_element_count);
1252
+ assert(data[j] != i);
1253
+ inbound_connections_num[data[j]]++;
1254
+ s.insert(data[j]);
1255
+ connections_checked++;
1191
1256
  }
1257
+ assert(s.size() == size);
1192
1258
  }
1193
- if(cur_element_count > 1){
1194
- int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1195
- for(int i=0; i < cur_element_count; i++){
1196
- assert(inbound_connections_num[i] > 0);
1197
- min1=std::min(inbound_connections_num[i],min1);
1198
- max1=std::max(inbound_connections_num[i],max1);
1199
- }
1200
- std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1259
+ }
1260
+ if (cur_element_count > 1) {
1261
+ int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
1262
+ for (int i=0; i < cur_element_count; i++) {
1263
+ assert(inbound_connections_num[i] > 0);
1264
+ min1 = std::min(inbound_connections_num[i], min1);
1265
+ max1 = std::max(inbound_connections_num[i], max1);
1201
1266
  }
1202
- std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1203
-
1267
+ std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1204
1268
  }
1205
-
1206
- };
1207
-
1208
- }
1269
+ std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1270
+ }
1271
+ };
1272
+ } // namespace hnswlib