hnswlib 0.6.2 → 0.7.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -10,1199 +10,1263 @@
10
10
  #include <list>
11
11
 
12
12
  namespace hnswlib {
13
- typedef unsigned int tableint;
14
- typedef unsigned int linklistsizeint;
15
-
16
- template<typename dist_t>
17
- class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
- public:
19
- static const tableint max_update_element_locks = 65536;
20
- HierarchicalNSW() : visited_list_pool_(nullptr), data_level0_memory_(nullptr), linkLists_(nullptr), cur_element_count(0) { }
21
- HierarchicalNSW(SpaceInterface<dist_t> *s) {
22
- }
13
+ typedef unsigned int tableint;
14
+ typedef unsigned int linklistsizeint;
15
+
16
+ template<typename dist_t>
17
+ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
+ public:
19
+ static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
20
+ static const unsigned char DELETE_MARK = 0x01;
21
+
22
+ size_t max_elements_{0};
23
+ mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
24
+ size_t size_data_per_element_{0};
25
+ size_t size_links_per_element_{0};
26
+ mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
27
+ size_t M_{0};
28
+ size_t maxM_{0};
29
+ size_t maxM0_{0};
30
+ size_t ef_construction_{0};
31
+ size_t ef_{ 0 };
32
+
33
+ double mult_{0.0}, revSize_{0.0};
34
+ int maxlevel_{0};
35
+
36
+ VisitedListPool *visited_list_pool_{nullptr};
37
+
38
+ // Locks operations with element by label value
39
+ mutable std::vector<std::mutex> label_op_locks_;
40
+
41
+ std::mutex global;
42
+ std::vector<std::mutex> link_list_locks_;
43
+
44
+ tableint enterpoint_node_{0};
45
+
46
+ size_t size_links_level0_{0};
47
+ size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 };
48
+
49
+ char *data_level0_memory_{nullptr};
50
+ char **linkLists_{nullptr};
51
+ std::vector<int> element_levels_; // keeps level of each element
52
+
53
+ size_t data_size_{0};
54
+
55
+ DISTFUNC<dist_t> fstdistfunc_;
56
+ void *dist_func_param_{nullptr};
57
+
58
+ mutable std::mutex label_lookup_lock; // lock for label_lookup_
59
+ std::unordered_map<labeltype, tableint> label_lookup_;
60
+
61
+ std::default_random_engine level_generator_;
62
+ std::default_random_engine update_probability_generator_;
63
+
64
+ mutable std::atomic<long> metric_distance_computations{0};
65
+ mutable std::atomic<long> metric_hops{0};
66
+
67
+ bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
68
+
69
+ std::mutex deleted_elements_lock; // lock for deleted_elements
70
+ std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
71
+
72
+ HierarchicalNSW() { }
73
+
74
+ HierarchicalNSW(SpaceInterface<dist_t> *s) {
75
+ }
76
+
77
+
78
+ HierarchicalNSW(
79
+ SpaceInterface<dist_t> *s,
80
+ const std::string &location,
81
+ bool nmslib = false,
82
+ size_t max_elements = 0,
83
+ bool allow_replace_deleted = false)
84
+ : allow_replace_deleted_(allow_replace_deleted) {
85
+ loadIndex(location, s, max_elements);
86
+ }
87
+
88
+
89
+ HierarchicalNSW(
90
+ SpaceInterface<dist_t> *s,
91
+ size_t max_elements,
92
+ size_t M = 16,
93
+ size_t ef_construction = 200,
94
+ size_t random_seed = 100,
95
+ bool allow_replace_deleted = false)
96
+ : link_list_locks_(max_elements),
97
+ label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
98
+ element_levels_(max_elements),
99
+ allow_replace_deleted_(allow_replace_deleted) {
100
+ max_elements_ = max_elements;
101
+ num_deleted_ = 0;
102
+ data_size_ = s->get_data_size();
103
+ fstdistfunc_ = s->get_dist_func();
104
+ dist_func_param_ = s->get_dist_func_param();
105
+ M_ = M;
106
+ maxM_ = M_;
107
+ maxM0_ = M_ * 2;
108
+ ef_construction_ = std::max(ef_construction, M_);
109
+ ef_ = 10;
110
+
111
+ level_generator_.seed(random_seed);
112
+ update_probability_generator_.seed(random_seed + 1);
113
+
114
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
115
+ size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
116
+ offsetData_ = size_links_level0_;
117
+ label_offset_ = size_links_level0_ + data_size_;
118
+ offsetLevel0_ = 0;
119
+
120
+ data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
121
+ if (data_level0_memory_ == nullptr)
122
+ throw std::runtime_error("Not enough memory");
123
+
124
+ cur_element_count = 0;
125
+
126
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
23
127
 
24
- HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
25
- loadIndex(location, s, max_elements);
26
- }
128
+ // initializations for special treatment of the first node
129
+ enterpoint_node_ = -1;
130
+ maxlevel_ = -1;
131
+
132
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
133
+ if (linkLists_ == nullptr)
134
+ throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
135
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
136
+ mult_ = 1 / log(1.0 * M_);
137
+ revSize_ = 1.0 / mult_;
138
+ }
27
139
 
28
- HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
29
- link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
30
- max_elements_ = max_elements;
31
-
32
- num_deleted_ = 0;
33
- data_size_ = s->get_data_size();
34
- fstdistfunc_ = s->get_dist_func();
35
- dist_func_param_ = s->get_dist_func_param();
36
- M_ = M;
37
- maxM_ = M_;
38
- maxM0_ = M_ * 2;
39
- ef_construction_ = std::max(ef_construction,M_);
40
- ef_ = 10;
41
-
42
- level_generator_.seed(random_seed);
43
- update_probability_generator_.seed(random_seed + 1);
44
-
45
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
46
- size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
47
- offsetData_ = size_links_level0_;
48
- label_offset_ = size_links_level0_ + data_size_;
49
- offsetLevel0_ = 0;
50
-
51
- data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
52
- if (data_level0_memory_ == nullptr)
53
- throw std::runtime_error("Not enough memory");
54
-
55
- cur_element_count = 0;
56
-
57
- visited_list_pool_ = new VisitedListPool(1, max_elements);
58
-
59
- //initializations for special treatment of the first node
60
- enterpoint_node_ = -1;
61
- maxlevel_ = -1;
62
-
63
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
64
- if (linkLists_ == nullptr)
65
- throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
66
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
67
- mult_ = 1 / log(1.0 * M_);
68
- revSize_ = 1.0 / mult_;
140
+
141
+ ~HierarchicalNSW() {
142
+ free(data_level0_memory_);
143
+ for (tableint i = 0; i < cur_element_count; i++) {
144
+ if (element_levels_[i] > 0)
145
+ free(linkLists_[i]);
69
146
  }
147
+ free(linkLists_);
148
+ delete visited_list_pool_;
149
+ }
70
150
 
71
- struct CompareByFirst {
72
- constexpr bool operator()(std::pair<dist_t, tableint> const &a,
73
- std::pair<dist_t, tableint> const &b) const noexcept {
74
- return a.first < b.first;
75
- }
76
- };
77
-
78
- ~HierarchicalNSW() {
79
- if (data_level0_memory_) free(data_level0_memory_);
80
- if (linkLists_) {
81
- for (tableint i = 0; i < cur_element_count; i++) {
82
- if (element_levels_[i] > 0)
83
- if (linkLists_[i]) free(linkLists_[i]);
84
- }
85
- free(linkLists_);
86
- }
87
- if (visited_list_pool_) delete visited_list_pool_;
151
+
152
+ struct CompareByFirst {
153
+ constexpr bool operator()(std::pair<dist_t, tableint> const& a,
154
+ std::pair<dist_t, tableint> const& b) const noexcept {
155
+ return a.first < b.first;
88
156
  }
157
+ };
89
158
 
90
- size_t max_elements_;
91
- size_t cur_element_count;
92
- size_t size_data_per_element_;
93
- size_t size_links_per_element_;
94
- size_t num_deleted_;
95
159
 
96
- size_t M_;
97
- size_t maxM_;
98
- size_t maxM0_;
99
- size_t ef_construction_;
160
+ void setEf(size_t ef) {
161
+ ef_ = ef;
162
+ }
100
163
 
101
- double mult_, revSize_;
102
- int maxlevel_;
103
164
 
165
+ inline std::mutex& getLabelOpMutex(labeltype label) const {
166
+ // calculate hash
167
+ size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
168
+ return label_op_locks_[lock_id];
169
+ }
104
170
 
105
- VisitedListPool *visited_list_pool_;
106
- std::mutex cur_element_count_guard_;
107
171
 
108
- std::vector<std::mutex> link_list_locks_;
172
+ inline labeltype getExternalLabel(tableint internal_id) const {
173
+ labeltype return_label;
174
+ memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
175
+ return return_label;
176
+ }
109
177
 
110
- // Locks to prevent race condition during update/insert of an element at same time.
111
- // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
112
- std::vector<std::mutex> link_list_update_locks_;
113
- tableint enterpoint_node_;
114
178
 
115
- size_t size_links_level0_;
116
- size_t offsetData_, offsetLevel0_;
179
+ inline void setExternalLabel(tableint internal_id, labeltype label) const {
180
+ memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
181
+ }
117
182
 
118
- char *data_level0_memory_;
119
- char **linkLists_;
120
- std::vector<int> element_levels_;
121
183
 
122
- size_t data_size_;
184
+ inline labeltype *getExternalLabeLp(tableint internal_id) const {
185
+ return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
186
+ }
123
187
 
124
- size_t label_offset_;
125
- DISTFUNC<dist_t> fstdistfunc_;
126
- void *dist_func_param_;
127
- std::unordered_map<labeltype, tableint> label_lookup_;
128
188
 
129
- std::default_random_engine level_generator_;
130
- std::default_random_engine update_probability_generator_;
189
+ inline char *getDataByInternalId(tableint internal_id) const {
190
+ return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
191
+ }
131
192
 
132
- inline labeltype getExternalLabel(tableint internal_id) const {
133
- labeltype return_label;
134
- memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
135
- return return_label;
136
- }
137
193
 
138
- inline void setExternalLabel(tableint internal_id, labeltype label) const {
139
- memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
140
- }
194
+ int getRandomLevel(double reverse_size) {
195
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
196
+ double r = -log(distribution(level_generator_)) * reverse_size;
197
+ return (int) r;
198
+ }
141
199
 
142
- inline labeltype *getExternalLabeLp(tableint internal_id) const {
143
- return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
144
- }
200
+ size_t getMaxElements() {
201
+ return max_elements_;
202
+ }
145
203
 
146
- inline char *getDataByInternalId(tableint internal_id) const {
147
- return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
148
- }
204
+ size_t getCurrentElementCount() {
205
+ return cur_element_count;
206
+ }
149
207
 
150
- int getRandomLevel(double reverse_size) {
151
- std::uniform_real_distribution<double> distribution(0.0, 1.0);
152
- double r = -log(distribution(level_generator_)) * reverse_size;
153
- return (int) r;
154
- }
208
+ size_t getDeletedCount() {
209
+ return num_deleted_;
210
+ }
155
211
 
212
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
213
+ searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
214
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
215
+ vl_type *visited_array = vl->mass;
216
+ vl_type visited_array_tag = vl->curV;
156
217
 
157
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
158
- searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
159
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
160
- vl_type *visited_array = vl->mass;
161
- vl_type visited_array_tag = vl->curV;
218
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
219
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
162
220
 
163
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
164
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
221
+ dist_t lowerBound;
222
+ if (!isMarkedDeleted(ep_id)) {
223
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
224
+ top_candidates.emplace(dist, ep_id);
225
+ lowerBound = dist;
226
+ candidateSet.emplace(-dist, ep_id);
227
+ } else {
228
+ lowerBound = std::numeric_limits<dist_t>::max();
229
+ candidateSet.emplace(-lowerBound, ep_id);
230
+ }
231
+ visited_array[ep_id] = visited_array_tag;
165
232
 
166
- dist_t lowerBound;
167
- if (!isMarkedDeleted(ep_id)) {
168
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
169
- top_candidates.emplace(dist, ep_id);
170
- lowerBound = dist;
171
- candidateSet.emplace(-dist, ep_id);
172
- } else {
173
- lowerBound = std::numeric_limits<dist_t>::max();
174
- candidateSet.emplace(-lowerBound, ep_id);
233
+ while (!candidateSet.empty()) {
234
+ std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
235
+ if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
236
+ break;
175
237
  }
176
- visited_array[ep_id] = visited_array_tag;
177
-
178
- while (!candidateSet.empty()) {
179
- std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
180
- if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
181
- break;
182
- }
183
- candidateSet.pop();
238
+ candidateSet.pop();
184
239
 
185
- tableint curNodeNum = curr_el_pair.second;
240
+ tableint curNodeNum = curr_el_pair.second;
186
241
 
187
- std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
242
+ std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
188
243
 
189
- int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
190
- if (layer == 0) {
191
- data = (int*)get_linklist0(curNodeNum);
192
- } else {
193
- data = (int*)get_linklist(curNodeNum, layer);
244
+ int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
245
+ if (layer == 0) {
246
+ data = (int*)get_linklist0(curNodeNum);
247
+ } else {
248
+ data = (int*)get_linklist(curNodeNum, layer);
194
249
  // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
195
- }
196
- size_t size = getListCount((linklistsizeint*)data);
197
- tableint *datal = (tableint *) (data + 1);
250
+ }
251
+ size_t size = getListCount((linklistsizeint*)data);
252
+ tableint *datal = (tableint *) (data + 1);
198
253
  #ifdef USE_SSE
199
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
200
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
201
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
202
- _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
254
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
255
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
256
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
257
+ _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
203
258
  #endif
204
259
 
205
- for (size_t j = 0; j < size; j++) {
206
- tableint candidate_id = *(datal + j);
260
+ for (size_t j = 0; j < size; j++) {
261
+ tableint candidate_id = *(datal + j);
207
262
  // if (candidate_id == 0) continue;
208
263
  #ifdef USE_SSE
209
- _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
210
- _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
264
+ _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
265
+ _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
211
266
  #endif
212
- if (visited_array[candidate_id] == visited_array_tag) continue;
213
- visited_array[candidate_id] = visited_array_tag;
214
- char *currObj1 = (getDataByInternalId(candidate_id));
267
+ if (visited_array[candidate_id] == visited_array_tag) continue;
268
+ visited_array[candidate_id] = visited_array_tag;
269
+ char *currObj1 = (getDataByInternalId(candidate_id));
215
270
 
216
- dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
217
- if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
218
- candidateSet.emplace(-dist1, candidate_id);
271
+ dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
272
+ if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
273
+ candidateSet.emplace(-dist1, candidate_id);
219
274
  #ifdef USE_SSE
220
- _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
275
+ _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
221
276
  #endif
222
277
 
223
- if (!isMarkedDeleted(candidate_id))
224
- top_candidates.emplace(dist1, candidate_id);
278
+ if (!isMarkedDeleted(candidate_id))
279
+ top_candidates.emplace(dist1, candidate_id);
225
280
 
226
- if (top_candidates.size() > ef_construction_)
227
- top_candidates.pop();
281
+ if (top_candidates.size() > ef_construction_)
282
+ top_candidates.pop();
228
283
 
229
- if (!top_candidates.empty())
230
- lowerBound = top_candidates.top().first;
231
- }
284
+ if (!top_candidates.empty())
285
+ lowerBound = top_candidates.top().first;
232
286
  }
233
287
  }
234
- visited_list_pool_->releaseVisitedList(vl);
235
-
236
- return top_candidates;
288
+ }
289
+ visited_list_pool_->releaseVisitedList(vl);
290
+
291
+ return top_candidates;
292
+ }
293
+
294
+
295
+ template <bool has_deletions, bool collect_metrics = false>
296
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
297
+ searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
298
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
299
+ vl_type *visited_array = vl->mass;
300
+ vl_type visited_array_tag = vl->curV;
301
+
302
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
303
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
304
+
305
+ dist_t lowerBound;
306
+ if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
307
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
308
+ lowerBound = dist;
309
+ top_candidates.emplace(dist, ep_id);
310
+ candidate_set.emplace(-dist, ep_id);
311
+ } else {
312
+ lowerBound = std::numeric_limits<dist_t>::max();
313
+ candidate_set.emplace(-lowerBound, ep_id);
237
314
  }
238
315
 
239
- mutable std::atomic<long> metric_distance_computations;
240
- mutable std::atomic<long> metric_hops;
241
-
242
- template <bool has_deletions, bool collect_metrics=false>
243
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
244
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
245
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
246
- vl_type *visited_array = vl->mass;
247
- vl_type visited_array_tag = vl->curV;
248
-
249
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
250
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
251
-
252
- dist_t lowerBound;
253
- if (!has_deletions || !isMarkedDeleted(ep_id)) {
254
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
255
- lowerBound = dist;
256
- top_candidates.emplace(dist, ep_id);
257
- candidate_set.emplace(-dist, ep_id);
258
- } else {
259
- lowerBound = std::numeric_limits<dist_t>::max();
260
- candidate_set.emplace(-lowerBound, ep_id);
261
- }
262
-
263
- visited_array[ep_id] = visited_array_tag;
264
-
265
- while (!candidate_set.empty()) {
316
+ visited_array[ep_id] = visited_array_tag;
266
317
 
267
- std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
318
+ while (!candidate_set.empty()) {
319
+ std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
268
320
 
269
- if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) {
270
- break;
271
- }
272
- candidate_set.pop();
321
+ if ((-current_node_pair.first) > lowerBound &&
322
+ (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
323
+ break;
324
+ }
325
+ candidate_set.pop();
273
326
 
274
- tableint current_node_id = current_node_pair.second;
275
- int *data = (int *) get_linklist0(current_node_id);
276
- size_t size = getListCount((linklistsizeint*)data);
327
+ tableint current_node_id = current_node_pair.second;
328
+ int *data = (int *) get_linklist0(current_node_id);
329
+ size_t size = getListCount((linklistsizeint*)data);
277
330
  // bool cur_node_deleted = isMarkedDeleted(current_node_id);
278
- if(collect_metrics){
279
- metric_hops++;
280
- metric_distance_computations+=size;
281
- }
331
+ if (collect_metrics) {
332
+ metric_hops++;
333
+ metric_distance_computations+=size;
334
+ }
282
335
 
283
336
  #ifdef USE_SSE
284
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
285
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
286
- _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
287
- _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
337
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
338
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
339
+ _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
340
+ _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
288
341
  #endif
289
342
 
290
- for (size_t j = 1; j <= size; j++) {
291
- int candidate_id = *(data + j);
343
+ for (size_t j = 1; j <= size; j++) {
344
+ int candidate_id = *(data + j);
292
345
  // if (candidate_id == 0) continue;
293
346
  #ifdef USE_SSE
294
- _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
295
- _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
296
- _MM_HINT_T0);////////////
347
+ _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
348
+ _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
349
+ _MM_HINT_T0); ////////////
297
350
  #endif
298
- if (!(visited_array[candidate_id] == visited_array_tag)) {
299
-
300
- visited_array[candidate_id] = visited_array_tag;
351
+ if (!(visited_array[candidate_id] == visited_array_tag)) {
352
+ visited_array[candidate_id] = visited_array_tag;
301
353
 
302
- char *currObj1 = (getDataByInternalId(candidate_id));
303
- dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
354
+ char *currObj1 = (getDataByInternalId(candidate_id));
355
+ dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
304
356
 
305
- if (top_candidates.size() < ef || lowerBound > dist) {
306
- candidate_set.emplace(-dist, candidate_id);
357
+ if (top_candidates.size() < ef || lowerBound > dist) {
358
+ candidate_set.emplace(-dist, candidate_id);
307
359
  #ifdef USE_SSE
308
- _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
309
- offsetLevel0_,///////////
310
- _MM_HINT_T0);////////////////////////
360
+ _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
361
+ offsetLevel0_, ///////////
362
+ _MM_HINT_T0); ////////////////////////
311
363
  #endif
312
364
 
313
- if (!has_deletions || !isMarkedDeleted(candidate_id))
314
- top_candidates.emplace(dist, candidate_id);
365
+ if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
366
+ top_candidates.emplace(dist, candidate_id);
315
367
 
316
- if (top_candidates.size() > ef)
317
- top_candidates.pop();
368
+ if (top_candidates.size() > ef)
369
+ top_candidates.pop();
318
370
 
319
- if (!top_candidates.empty())
320
- lowerBound = top_candidates.top().first;
321
- }
371
+ if (!top_candidates.empty())
372
+ lowerBound = top_candidates.top().first;
322
373
  }
323
374
  }
324
375
  }
325
-
326
- visited_list_pool_->releaseVisitedList(vl);
327
- return top_candidates;
328
376
  }
329
377
 
330
- void getNeighborsByHeuristic2(
331
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
332
- const size_t M) {
333
- if (top_candidates.size() < M) {
334
- return;
335
- }
378
+ visited_list_pool_->releaseVisitedList(vl);
379
+ return top_candidates;
380
+ }
336
381
 
337
- std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
338
- std::vector<std::pair<dist_t, tableint>> return_list;
339
- while (top_candidates.size() > 0) {
340
- queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
341
- top_candidates.pop();
342
- }
343
382
 
344
- while (queue_closest.size()) {
345
- if (return_list.size() >= M)
383
+ void getNeighborsByHeuristic2(
384
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
385
+ const size_t M) {
386
+ if (top_candidates.size() < M) {
387
+ return;
388
+ }
389
+
390
+ std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
391
+ std::vector<std::pair<dist_t, tableint>> return_list;
392
+ while (top_candidates.size() > 0) {
393
+ queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
394
+ top_candidates.pop();
395
+ }
396
+
397
+ while (queue_closest.size()) {
398
+ if (return_list.size() >= M)
399
+ break;
400
+ std::pair<dist_t, tableint> curent_pair = queue_closest.top();
401
+ dist_t dist_to_query = -curent_pair.first;
402
+ queue_closest.pop();
403
+ bool good = true;
404
+
405
+ for (std::pair<dist_t, tableint> second_pair : return_list) {
406
+ dist_t curdist =
407
+ fstdistfunc_(getDataByInternalId(second_pair.second),
408
+ getDataByInternalId(curent_pair.second),
409
+ dist_func_param_);
410
+ if (curdist < dist_to_query) {
411
+ good = false;
346
412
  break;
347
- std::pair<dist_t, tableint> curent_pair = queue_closest.top();
348
- dist_t dist_to_query = -curent_pair.first;
349
- queue_closest.pop();
350
- bool good = true;
351
-
352
- for (std::pair<dist_t, tableint> second_pair : return_list) {
353
- dist_t curdist =
354
- fstdistfunc_(getDataByInternalId(second_pair.second),
355
- getDataByInternalId(curent_pair.second),
356
- dist_func_param_);;
357
- if (curdist < dist_to_query) {
358
- good = false;
359
- break;
360
- }
361
- }
362
- if (good) {
363
- return_list.push_back(curent_pair);
364
413
  }
365
414
  }
366
-
367
- for (std::pair<dist_t, tableint> curent_pair : return_list) {
368
- top_candidates.emplace(-curent_pair.first, curent_pair.second);
415
+ if (good) {
416
+ return_list.push_back(curent_pair);
369
417
  }
370
418
  }
371
419
 
420
+ for (std::pair<dist_t, tableint> curent_pair : return_list) {
421
+ top_candidates.emplace(-curent_pair.first, curent_pair.second);
422
+ }
423
+ }
372
424
 
373
- linklistsizeint *get_linklist0(tableint internal_id) const {
374
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
375
- };
376
425
 
377
- linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
378
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
379
- };
426
+ linklistsizeint *get_linklist0(tableint internal_id) const {
427
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
428
+ }
380
429
 
381
- linklistsizeint *get_linklist(tableint internal_id, int level) const {
382
- return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
383
- };
384
430
 
385
- linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
386
- return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
387
- };
431
+ linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
432
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
433
+ }
388
434
 
389
- tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
390
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
391
- int level, bool isUpdate) {
392
- size_t Mcurmax = level ? maxM_ : maxM0_;
393
- getNeighborsByHeuristic2(top_candidates, M_);
394
- if (top_candidates.size() > M_)
395
- throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
396
435
 
397
- std::vector<tableint> selectedNeighbors;
398
- selectedNeighbors.reserve(M_);
399
- while (top_candidates.size() > 0) {
400
- selectedNeighbors.push_back(top_candidates.top().second);
401
- top_candidates.pop();
402
- }
436
+ linklistsizeint *get_linklist(tableint internal_id, int level) const {
437
+ return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
438
+ }
403
439
 
404
- tableint next_closest_entry_point = selectedNeighbors.back();
405
440
 
406
- {
407
- linklistsizeint *ll_cur;
408
- if (level == 0)
409
- ll_cur = get_linklist0(cur_c);
410
- else
411
- ll_cur = get_linklist(cur_c, level);
441
+ linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
442
+ return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
443
+ }
412
444
 
413
- if (*ll_cur && !isUpdate) {
414
- throw std::runtime_error("The newly inserted element should have blank link list");
415
- }
416
- setListCount(ll_cur,selectedNeighbors.size());
417
- tableint *data = (tableint *) (ll_cur + 1);
418
- for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
419
- if (data[idx] && !isUpdate)
420
- throw std::runtime_error("Possible memory corruption");
421
- if (level > element_levels_[selectedNeighbors[idx]])
422
- throw std::runtime_error("Trying to make a link on a non-existent level");
423
445
 
424
- data[idx] = selectedNeighbors[idx];
446
+ tableint mutuallyConnectNewElement(
447
+ const void *data_point,
448
+ tableint cur_c,
449
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
450
+ int level,
451
+ bool isUpdate) {
452
+ size_t Mcurmax = level ? maxM_ : maxM0_;
453
+ getNeighborsByHeuristic2(top_candidates, M_);
454
+ if (top_candidates.size() > M_)
455
+ throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
425
456
 
426
- }
457
+ std::vector<tableint> selectedNeighbors;
458
+ selectedNeighbors.reserve(M_);
459
+ while (top_candidates.size() > 0) {
460
+ selectedNeighbors.push_back(top_candidates.top().second);
461
+ top_candidates.pop();
462
+ }
463
+
464
+ tableint next_closest_entry_point = selectedNeighbors.back();
465
+
466
+ {
467
+ // lock only during the update
468
+ // because during the addition the lock for cur_c is already acquired
469
+ std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
470
+ if (isUpdate) {
471
+ lock.lock();
427
472
  }
473
+ linklistsizeint *ll_cur;
474
+ if (level == 0)
475
+ ll_cur = get_linklist0(cur_c);
476
+ else
477
+ ll_cur = get_linklist(cur_c, level);
428
478
 
479
+ if (*ll_cur && !isUpdate) {
480
+ throw std::runtime_error("The newly inserted element should have blank link list");
481
+ }
482
+ setListCount(ll_cur, selectedNeighbors.size());
483
+ tableint *data = (tableint *) (ll_cur + 1);
429
484
  for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
485
+ if (data[idx] && !isUpdate)
486
+ throw std::runtime_error("Possible memory corruption");
487
+ if (level > element_levels_[selectedNeighbors[idx]])
488
+ throw std::runtime_error("Trying to make a link on a non-existent level");
430
489
 
431
- std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
490
+ data[idx] = selectedNeighbors[idx];
491
+ }
492
+ }
432
493
 
433
- linklistsizeint *ll_other;
434
- if (level == 0)
435
- ll_other = get_linklist0(selectedNeighbors[idx]);
436
- else
437
- ll_other = get_linklist(selectedNeighbors[idx], level);
494
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
495
+ std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
438
496
 
439
- size_t sz_link_list_other = getListCount(ll_other);
497
+ linklistsizeint *ll_other;
498
+ if (level == 0)
499
+ ll_other = get_linklist0(selectedNeighbors[idx]);
500
+ else
501
+ ll_other = get_linklist(selectedNeighbors[idx], level);
440
502
 
441
- if (sz_link_list_other > Mcurmax)
442
- throw std::runtime_error("Bad value of sz_link_list_other");
443
- if (selectedNeighbors[idx] == cur_c)
444
- throw std::runtime_error("Trying to connect an element to itself");
445
- if (level > element_levels_[selectedNeighbors[idx]])
446
- throw std::runtime_error("Trying to make a link on a non-existent level");
503
+ size_t sz_link_list_other = getListCount(ll_other);
447
504
 
448
- tableint *data = (tableint *) (ll_other + 1);
505
+ if (sz_link_list_other > Mcurmax)
506
+ throw std::runtime_error("Bad value of sz_link_list_other");
507
+ if (selectedNeighbors[idx] == cur_c)
508
+ throw std::runtime_error("Trying to connect an element to itself");
509
+ if (level > element_levels_[selectedNeighbors[idx]])
510
+ throw std::runtime_error("Trying to make a link on a non-existent level");
449
511
 
450
- bool is_cur_c_present = false;
451
- if (isUpdate) {
452
- for (size_t j = 0; j < sz_link_list_other; j++) {
453
- if (data[j] == cur_c) {
454
- is_cur_c_present = true;
455
- break;
456
- }
512
+ tableint *data = (tableint *) (ll_other + 1);
513
+
514
+ bool is_cur_c_present = false;
515
+ if (isUpdate) {
516
+ for (size_t j = 0; j < sz_link_list_other; j++) {
517
+ if (data[j] == cur_c) {
518
+ is_cur_c_present = true;
519
+ break;
457
520
  }
458
521
  }
522
+ }
459
523
 
460
- // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
461
- if (!is_cur_c_present) {
462
- if (sz_link_list_other < Mcurmax) {
463
- data[sz_link_list_other] = cur_c;
464
- setListCount(ll_other, sz_link_list_other + 1);
465
- } else {
466
- // finding the "weakest" element to replace it with the new one
467
- dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
468
- dist_func_param_);
469
- // Heuristic:
470
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
471
- candidates.emplace(d_max, cur_c);
472
-
473
- for (size_t j = 0; j < sz_link_list_other; j++) {
474
- candidates.emplace(
475
- fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
476
- dist_func_param_), data[j]);
477
- }
524
+ // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
525
+ if (!is_cur_c_present) {
526
+ if (sz_link_list_other < Mcurmax) {
527
+ data[sz_link_list_other] = cur_c;
528
+ setListCount(ll_other, sz_link_list_other + 1);
529
+ } else {
530
+ // finding the "weakest" element to replace it with the new one
531
+ dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
532
+ dist_func_param_);
533
+ // Heuristic:
534
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
535
+ candidates.emplace(d_max, cur_c);
478
536
 
479
- getNeighborsByHeuristic2(candidates, Mcurmax);
537
+ for (size_t j = 0; j < sz_link_list_other; j++) {
538
+ candidates.emplace(
539
+ fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
540
+ dist_func_param_), data[j]);
541
+ }
480
542
 
481
- int indx = 0;
482
- while (candidates.size() > 0) {
483
- data[indx] = candidates.top().second;
484
- candidates.pop();
485
- indx++;
486
- }
543
+ getNeighborsByHeuristic2(candidates, Mcurmax);
487
544
 
488
- setListCount(ll_other, indx);
489
- // Nearest K:
490
- /*int indx = -1;
491
- for (int j = 0; j < sz_link_list_other; j++) {
492
- dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
493
- if (d > d_max) {
494
- indx = j;
495
- d_max = d;
496
- }
545
+ int indx = 0;
546
+ while (candidates.size() > 0) {
547
+ data[indx] = candidates.top().second;
548
+ candidates.pop();
549
+ indx++;
550
+ }
551
+
552
+ setListCount(ll_other, indx);
553
+ // Nearest K:
554
+ /*int indx = -1;
555
+ for (int j = 0; j < sz_link_list_other; j++) {
556
+ dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
557
+ if (d > d_max) {
558
+ indx = j;
559
+ d_max = d;
497
560
  }
498
- if (indx >= 0) {
499
- data[indx] = cur_c;
500
- } */
501
561
  }
562
+ if (indx >= 0) {
563
+ data[indx] = cur_c;
564
+ } */
502
565
  }
503
566
  }
504
-
505
- return next_closest_entry_point;
506
567
  }
507
568
 
508
- std::mutex global;
509
- size_t ef_;
569
+ return next_closest_entry_point;
570
+ }
510
571
 
511
- void setEf(size_t ef) {
512
- ef_ = ef;
513
- }
514
572
 
573
+ void resizeIndex(size_t new_max_elements) {
574
+ if (new_max_elements < cur_element_count)
575
+ throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
515
576
 
516
- std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
517
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
518
- if (cur_element_count == 0) return top_candidates;
519
- tableint currObj = enterpoint_node_;
520
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
577
+ delete visited_list_pool_;
578
+ visited_list_pool_ = new VisitedListPool(1, new_max_elements);
521
579
 
522
- for (size_t level = maxlevel_; level > 0; level--) {
523
- bool changed = true;
524
- while (changed) {
525
- changed = false;
526
- int *data;
527
- data = (int *) get_linklist(currObj,level);
528
- int size = getListCount(data);
529
- tableint *datal = (tableint *) (data + 1);
530
- for (int i = 0; i < size; i++) {
531
- tableint cand = datal[i];
532
- if (cand < 0 || cand > max_elements_)
533
- throw std::runtime_error("cand error");
534
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
580
+ element_levels_.resize(new_max_elements);
535
581
 
536
- if (d < curdist) {
537
- curdist = d;
538
- currObj = cand;
539
- changed = true;
540
- }
541
- }
542
- }
543
- }
582
+ std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
544
583
 
545
- if (num_deleted_) {
546
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
547
- ef_);
548
- top_candidates.swap(top_candidates1);
549
- }
550
- else{
551
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
552
- ef_);
553
- top_candidates.swap(top_candidates1);
554
- }
555
-
556
- while (top_candidates.size() > k) {
557
- top_candidates.pop();
558
- }
559
- return top_candidates;
560
- };
584
+ // Reallocate base layer
585
+ char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
586
+ if (data_level0_memory_new == nullptr)
587
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
588
+ data_level0_memory_ = data_level0_memory_new;
561
589
 
562
- void resizeIndex(size_t new_max_elements){
563
- if (new_max_elements<cur_element_count)
564
- throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
590
+ // Reallocate all other layers
591
+ char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
592
+ if (linkLists_new == nullptr)
593
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
594
+ linkLists_ = linkLists_new;
565
595
 
596
+ max_elements_ = new_max_elements;
597
+ }
566
598
 
567
- delete visited_list_pool_;
568
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
569
599
 
600
+ void saveIndex(const std::string &location) {
601
+ std::ofstream output(location, std::ios::binary);
602
+ std::streampos position;
570
603
 
571
- element_levels_.resize(new_max_elements);
604
+ writeBinaryPOD(output, offsetLevel0_);
605
+ writeBinaryPOD(output, max_elements_);
606
+ writeBinaryPOD(output, cur_element_count);
607
+ writeBinaryPOD(output, size_data_per_element_);
608
+ writeBinaryPOD(output, label_offset_);
609
+ writeBinaryPOD(output, offsetData_);
610
+ writeBinaryPOD(output, maxlevel_);
611
+ writeBinaryPOD(output, enterpoint_node_);
612
+ writeBinaryPOD(output, maxM_);
572
613
 
573
- std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
614
+ writeBinaryPOD(output, maxM0_);
615
+ writeBinaryPOD(output, M_);
616
+ writeBinaryPOD(output, mult_);
617
+ writeBinaryPOD(output, ef_construction_);
574
618
 
575
- // Reallocate base layer
576
- char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
577
- if (data_level0_memory_new == nullptr)
578
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
579
- data_level0_memory_ = data_level0_memory_new;
619
+ output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
580
620
 
581
- // Reallocate all other layers
582
- char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
583
- if (linkLists_new == nullptr)
584
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
585
- linkLists_ = linkLists_new;
586
-
587
- max_elements_ = new_max_elements;
621
+ for (size_t i = 0; i < cur_element_count; i++) {
622
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
623
+ writeBinaryPOD(output, linkListSize);
624
+ if (linkListSize)
625
+ output.write(linkLists_[i], linkListSize);
588
626
  }
627
+ output.close();
628
+ }
629
+
630
+
631
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) {
632
+ std::ifstream input(location, std::ios::binary);
633
+
634
+ if (!input.is_open())
635
+ throw std::runtime_error("Cannot open file");
636
+
637
+ // get file size:
638
+ input.seekg(0, input.end);
639
+ std::streampos total_filesize = input.tellg();
640
+ input.seekg(0, input.beg);
641
+
642
+ readBinaryPOD(input, offsetLevel0_);
643
+ readBinaryPOD(input, max_elements_);
644
+ readBinaryPOD(input, cur_element_count);
645
+
646
+ size_t max_elements = max_elements_i;
647
+ if (max_elements < cur_element_count)
648
+ max_elements = max_elements_;
649
+ max_elements_ = max_elements;
650
+ readBinaryPOD(input, size_data_per_element_);
651
+ readBinaryPOD(input, label_offset_);
652
+ readBinaryPOD(input, offsetData_);
653
+ readBinaryPOD(input, maxlevel_);
654
+ readBinaryPOD(input, enterpoint_node_);
655
+
656
+ readBinaryPOD(input, maxM_);
657
+ readBinaryPOD(input, maxM0_);
658
+ readBinaryPOD(input, M_);
659
+ readBinaryPOD(input, mult_);
660
+ readBinaryPOD(input, ef_construction_);
661
+
662
+ data_size_ = s->get_data_size();
663
+ fstdistfunc_ = s->get_dist_func();
664
+ dist_func_param_ = s->get_dist_func_param();
665
+
666
+ auto pos = input.tellg();
667
+
668
+ /// Optional - check if index is ok:
669
+ input.seekg(cur_element_count * size_data_per_element_, input.cur);
670
+ for (size_t i = 0; i < cur_element_count; i++) {
671
+ if (input.tellg() < 0 || input.tellg() >= total_filesize) {
672
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
673
+ }
589
674
 
590
- void saveIndex(const std::string &location) {
591
- std::ofstream output(location, std::ios::binary);
592
- std::streampos position;
593
-
594
- writeBinaryPOD(output, offsetLevel0_);
595
- writeBinaryPOD(output, max_elements_);
596
- writeBinaryPOD(output, cur_element_count);
597
- writeBinaryPOD(output, size_data_per_element_);
598
- writeBinaryPOD(output, label_offset_);
599
- writeBinaryPOD(output, offsetData_);
600
- writeBinaryPOD(output, maxlevel_);
601
- writeBinaryPOD(output, enterpoint_node_);
602
- writeBinaryPOD(output, maxM_);
603
-
604
- writeBinaryPOD(output, maxM0_);
605
- writeBinaryPOD(output, M_);
606
- writeBinaryPOD(output, mult_);
607
- writeBinaryPOD(output, ef_construction_);
608
-
609
- output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
610
-
611
- for (size_t i = 0; i < cur_element_count; i++) {
612
- unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
613
- writeBinaryPOD(output, linkListSize);
614
- if (linkListSize)
615
- output.write(linkLists_[i], linkListSize);
675
+ unsigned int linkListSize;
676
+ readBinaryPOD(input, linkListSize);
677
+ if (linkListSize != 0) {
678
+ input.seekg(linkListSize, input.cur);
616
679
  }
617
- output.close();
618
680
  }
619
681
 
620
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
621
- std::ifstream input(location, std::ios::binary);
622
-
623
- if (!input.is_open())
624
- throw std::runtime_error("Cannot open file");
625
-
626
- // get file size:
627
- input.seekg(0,input.end);
628
- std::streampos total_filesize=input.tellg();
629
- input.seekg(0,input.beg);
630
-
631
- readBinaryPOD(input, offsetLevel0_);
632
- readBinaryPOD(input, max_elements_);
633
- readBinaryPOD(input, cur_element_count);
682
+ // throw exception if it either corrupted or old index
683
+ if (input.tellg() != total_filesize)
684
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
685
+
686
+ input.clear();
687
+ /// Optional check end
688
+
689
+ input.seekg(pos, input.beg);
690
+
691
+ data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
692
+ if (data_level0_memory_ == nullptr)
693
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
694
+ input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
695
+
696
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
697
+
698
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
699
+ std::vector<std::mutex>(max_elements).swap(link_list_locks_);
700
+ std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
701
+
702
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
703
+
704
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
705
+ if (linkLists_ == nullptr)
706
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
707
+ element_levels_ = std::vector<int>(max_elements);
708
+ revSize_ = 1.0 / mult_;
709
+ ef_ = 10;
710
+ for (size_t i = 0; i < cur_element_count; i++) {
711
+ label_lookup_[getExternalLabel(i)] = i;
712
+ unsigned int linkListSize;
713
+ readBinaryPOD(input, linkListSize);
714
+ if (linkListSize == 0) {
715
+ element_levels_[i] = 0;
716
+ linkLists_[i] = nullptr;
717
+ } else {
718
+ element_levels_[i] = linkListSize / size_links_per_element_;
719
+ linkLists_[i] = (char *) malloc(linkListSize);
720
+ if (linkLists_[i] == nullptr)
721
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
722
+ input.read(linkLists_[i], linkListSize);
723
+ }
724
+ }
634
725
 
635
- size_t max_elements = max_elements_i;
636
- if(max_elements < cur_element_count)
637
- max_elements = max_elements_;
638
- max_elements_ = max_elements;
639
- readBinaryPOD(input, size_data_per_element_);
640
- readBinaryPOD(input, label_offset_);
641
- readBinaryPOD(input, offsetData_);
642
- readBinaryPOD(input, maxlevel_);
643
- readBinaryPOD(input, enterpoint_node_);
726
+ for (size_t i = 0; i < cur_element_count; i++) {
727
+ if (isMarkedDeleted(i)) {
728
+ num_deleted_ += 1;
729
+ if (allow_replace_deleted_) deleted_elements.insert(i);
730
+ }
731
+ }
644
732
 
645
- readBinaryPOD(input, maxM_);
646
- readBinaryPOD(input, maxM0_);
647
- readBinaryPOD(input, M_);
648
- readBinaryPOD(input, mult_);
649
- readBinaryPOD(input, ef_construction_);
733
+ input.close();
650
734
 
735
+ return;
736
+ }
651
737
 
652
- data_size_ = s->get_data_size();
653
- fstdistfunc_ = s->get_dist_func();
654
- dist_func_param_ = s->get_dist_func_param();
655
738
 
656
- auto pos=input.tellg();
739
+ template<typename data_t>
740
+ std::vector<data_t> getDataByLabel(labeltype label) const {
741
+ // lock all operations with element by label
742
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
657
743
 
744
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
745
+ auto search = label_lookup_.find(label);
746
+ if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
747
+ throw std::runtime_error("Label not found");
748
+ }
749
+ tableint internalId = search->second;
750
+ lock_table.unlock();
751
+
752
+ char* data_ptrv = getDataByInternalId(internalId);
753
+ size_t dim = *((size_t *) dist_func_param_);
754
+ std::vector<data_t> data;
755
+ data_t* data_ptr = (data_t*) data_ptrv;
756
+ for (int i = 0; i < dim; i++) {
757
+ data.push_back(*data_ptr);
758
+ data_ptr += 1;
759
+ }
760
+ return data;
761
+ }
658
762
 
659
- /// Optional - check if index is ok:
660
763
 
661
- input.seekg(cur_element_count * size_data_per_element_,input.cur);
662
- for (size_t i = 0; i < cur_element_count; i++) {
663
- if(input.tellg() < 0 || input.tellg()>=total_filesize){
664
- throw std::runtime_error("Index seems to be corrupted or unsupported");
665
- }
764
+ /*
765
+ * Marks an element with the given label deleted, does NOT really change the current graph.
766
+ */
767
+ void markDelete(labeltype label) {
768
+ // lock all operations with element by label
769
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
666
770
 
667
- unsigned int linkListSize;
668
- readBinaryPOD(input, linkListSize);
669
- if (linkListSize != 0) {
670
- input.seekg(linkListSize,input.cur);
671
- }
771
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
772
+ auto search = label_lookup_.find(label);
773
+ if (search == label_lookup_.end()) {
774
+ throw std::runtime_error("Label not found");
775
+ }
776
+ tableint internalId = search->second;
777
+ lock_table.unlock();
778
+
779
+ markDeletedInternal(internalId);
780
+ }
781
+
782
+
783
+ /*
784
+ * Uses the last 16 bits of the memory for the linked list size to store the mark,
785
+ * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases.
786
+ */
787
+ void markDeletedInternal(tableint internalId) {
788
+ assert(internalId < cur_element_count);
789
+ if (!isMarkedDeleted(internalId)) {
790
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
791
+ *ll_cur |= DELETE_MARK;
792
+ num_deleted_ += 1;
793
+ if (allow_replace_deleted_) {
794
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
795
+ deleted_elements.insert(internalId);
672
796
  }
797
+ } else {
798
+ throw std::runtime_error("The requested to delete element is already deleted");
799
+ }
800
+ }
801
+
802
+
803
+ /*
804
+ * Removes the deleted mark of the node, does NOT really change the current graph.
805
+ *
806
+ * Note: the method is not safe to use when replacement of deleted elements is enabled,
807
+ * because elements marked as deleted can be completely removed by addPoint
808
+ */
809
+ void unmarkDelete(labeltype label) {
810
+ // lock all operations with element by label
811
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
812
+
813
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
814
+ auto search = label_lookup_.find(label);
815
+ if (search == label_lookup_.end()) {
816
+ throw std::runtime_error("Label not found");
817
+ }
818
+ tableint internalId = search->second;
819
+ lock_table.unlock();
820
+
821
+ unmarkDeletedInternal(internalId);
822
+ }
823
+
824
+
825
+
826
+ /*
827
+ * Remove the deleted mark of the node.
828
+ */
829
+ void unmarkDeletedInternal(tableint internalId) {
830
+ assert(internalId < cur_element_count);
831
+ if (isMarkedDeleted(internalId)) {
832
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;
833
+ *ll_cur &= ~DELETE_MARK;
834
+ num_deleted_ -= 1;
835
+ if (allow_replace_deleted_) {
836
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
837
+ deleted_elements.erase(internalId);
838
+ }
839
+ } else {
840
+ throw std::runtime_error("The requested to undelete element is not deleted");
841
+ }
842
+ }
673
843
 
674
- // throw exception if it either corrupted or old index
675
- if(input.tellg()!=total_filesize)
676
- throw std::runtime_error("Index seems to be corrupted or unsupported");
677
-
678
- input.clear();
679
-
680
- /// Optional check end
681
-
682
- input.seekg(pos,input.beg);
683
-
684
- data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
685
- if (data_level0_memory_ == nullptr)
686
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
687
- input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
688
844
 
689
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
845
+ /*
846
+ * Checks the first 16 bits of the memory to see if the element is marked deleted.
847
+ */
848
+ bool isMarkedDeleted(tableint internalId) const {
849
+ unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2;
850
+ return *ll_cur & DELETE_MARK;
851
+ }
690
852
 
691
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
692
- std::vector<std::mutex>(max_elements).swap(link_list_locks_);
693
- std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
694
853
 
695
- visited_list_pool_ = new VisitedListPool(1, max_elements);
854
+ unsigned short int getListCount(linklistsizeint * ptr) const {
855
+ return *((unsigned short int *)ptr);
856
+ }
696
857
 
697
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
698
- if (linkLists_ == nullptr)
699
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
700
- element_levels_ = std::vector<int>(max_elements);
701
- revSize_ = 1.0 / mult_;
702
- ef_ = 10;
703
- for (size_t i = 0; i < cur_element_count; i++) {
704
- label_lookup_[getExternalLabel(i)]=i;
705
- unsigned int linkListSize;
706
- readBinaryPOD(input, linkListSize);
707
- if (linkListSize == 0) {
708
- element_levels_[i] = 0;
709
858
 
710
- linkLists_[i] = nullptr;
711
- } else {
712
- element_levels_[i] = linkListSize / size_links_per_element_;
713
- linkLists_[i] = (char *) malloc(linkListSize);
714
- if (linkLists_[i] == nullptr)
715
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
716
- input.read(linkLists_[i], linkListSize);
717
- }
718
- }
719
-
720
- for (size_t i = 0; i < cur_element_count; i++) {
721
- if(isMarkedDeleted(i))
722
- num_deleted_ += 1;
723
- }
859
+ void setListCount(linklistsizeint * ptr, unsigned short int size) const {
860
+ *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
861
+ }
724
862
 
725
- input.close();
726
863
 
727
- return;
864
+ /*
865
+ * Adds point. Updates the point if it is already in the index.
866
+ * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
867
+ */
868
+ void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
869
+ if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
870
+ throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
728
871
  }
729
872
 
730
- template<typename data_t>
731
- std::vector<data_t> getDataByLabel(labeltype label) const
732
- {
733
- tableint label_c;
734
- auto search = label_lookup_.find(label);
735
- if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
736
- throw std::runtime_error("Label not found");
737
- }
738
- label_c = search->second;
739
-
740
- char* data_ptrv = getDataByInternalId(label_c);
741
- size_t dim = *((size_t *) dist_func_param_);
742
- std::vector<data_t> data;
743
- data_t* data_ptr = (data_t*) data_ptrv;
744
- for (int i = 0; i < dim; i++) {
745
- data.push_back(*data_ptr);
746
- data_ptr += 1;
747
- }
748
- return data;
873
+ // lock all operations with element by label
874
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
875
+ if (!replace_deleted) {
876
+ addPoint(data_point, label, -1);
877
+ return;
749
878
  }
750
-
751
- static const unsigned char DELETE_MARK = 0x01;
752
- // static const unsigned char REUSE_MARK = 0x10;
753
- /**
754
- * Marks an element with the given label deleted, does NOT really change the current graph.
755
- * @param label
756
- */
757
- void markDelete(labeltype label)
758
- {
759
- auto search = label_lookup_.find(label);
760
- if (search == label_lookup_.end()) {
761
- throw std::runtime_error("Label not found");
762
- }
763
- tableint internalId = search->second;
764
- markDeletedInternal(internalId);
879
+ // check if there is vacant place
880
+ tableint internal_id_replaced;
881
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
882
+ bool is_vacant_place = !deleted_elements.empty();
883
+ if (is_vacant_place) {
884
+ internal_id_replaced = *deleted_elements.begin();
885
+ deleted_elements.erase(internal_id_replaced);
765
886
  }
766
-
767
- /**
768
- * Uses the first 8 bits of the memory for the linked list to store the mark,
769
- * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
770
- * @param internalId
771
- */
772
- void markDeletedInternal(tableint internalId) {
773
- assert(internalId < cur_element_count);
774
- if (!isMarkedDeleted(internalId))
775
- {
776
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
777
- *ll_cur |= DELETE_MARK;
778
- num_deleted_ += 1;
779
- }
780
- else
781
- {
782
- throw std::runtime_error("The requested to delete element is already deleted");
783
- }
887
+ lock_deleted_elements.unlock();
888
+
889
+ // if there is no vacant place then add or update point
890
+ // else add point to vacant place
891
+ if (!is_vacant_place) {
892
+ addPoint(data_point, label, -1);
893
+ } else {
894
+ // we assume that there are no concurrent operations on deleted element
895
+ labeltype label_replaced = getExternalLabel(internal_id_replaced);
896
+ setExternalLabel(internal_id_replaced, label);
897
+
898
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
899
+ label_lookup_.erase(label_replaced);
900
+ label_lookup_[label] = internal_id_replaced;
901
+ lock_table.unlock();
902
+
903
+ unmarkDeletedInternal(internal_id_replaced);
904
+ updatePoint(data_point, internal_id_replaced, 1.0);
784
905
  }
906
+ }
785
907
 
786
- /**
787
- * Remove the deleted mark of the node, does NOT really change the current graph.
788
- * @param label
789
- */
790
- void unmarkDelete(labeltype label)
791
- {
792
- auto search = label_lookup_.find(label);
793
- if (search == label_lookup_.end()) {
794
- throw std::runtime_error("Label not found");
795
- }
796
- tableint internalId = search->second;
797
- unmarkDeletedInternal(internalId);
798
- }
799
908
 
800
- /**
801
- * Remove the deleted mark of the node.
802
- * @param internalId
803
- */
804
- void unmarkDeletedInternal(tableint internalId) {
805
- assert(internalId < cur_element_count);
806
- if (isMarkedDeleted(internalId))
807
- {
808
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
809
- *ll_cur &= ~DELETE_MARK;
810
- num_deleted_ -= 1;
811
- }
812
- else
813
- {
814
- throw std::runtime_error("The requested to undelete element is not deleted");
815
- }
816
- }
909
+ void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
910
+ // update the feature vector associated with existing point with new vector
911
+ memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
817
912
 
818
- /**
819
- * Checks the first 8 bits of the memory to see if the element is marked deleted.
820
- * @param internalId
821
- * @return
822
- */
823
- bool isMarkedDeleted(tableint internalId) const {
824
- unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
825
- return *ll_cur & DELETE_MARK;
826
- }
913
+ int maxLevelCopy = maxlevel_;
914
+ tableint entryPointCopy = enterpoint_node_;
915
+ // If point to be updated is entry point and graph just contains single element then just return.
916
+ if (entryPointCopy == internalId && cur_element_count == 1)
917
+ return;
827
918
 
828
- unsigned short int getListCount(linklistsizeint * ptr) const {
829
- return *((unsigned short int *)ptr);
830
- }
919
+ int elemLevel = element_levels_[internalId];
920
+ std::uniform_real_distribution<float> distribution(0.0, 1.0);
921
+ for (int layer = 0; layer <= elemLevel; layer++) {
922
+ std::unordered_set<tableint> sCand;
923
+ std::unordered_set<tableint> sNeigh;
924
+ std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
925
+ if (listOneHop.size() == 0)
926
+ continue;
831
927
 
832
- void setListCount(linklistsizeint * ptr, unsigned short int size) const {
833
- *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
834
- }
928
+ sCand.insert(internalId);
835
929
 
836
- void addPoint(const void *data_point, labeltype label) {
837
- addPoint(data_point, label,-1);
838
- }
930
+ for (auto&& elOneHop : listOneHop) {
931
+ sCand.insert(elOneHop);
839
932
 
840
- void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
841
- // update the feature vector associated with existing point with new vector
842
- memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
843
-
844
- int maxLevelCopy = maxlevel_;
845
- tableint entryPointCopy = enterpoint_node_;
846
- // If point to be updated is entry point and graph just contains single element then just return.
847
- if (entryPointCopy == internalId && cur_element_count == 1)
848
- return;
849
-
850
- int elemLevel = element_levels_[internalId];
851
- std::uniform_real_distribution<float> distribution(0.0, 1.0);
852
- for (int layer = 0; layer <= elemLevel; layer++) {
853
- std::unordered_set<tableint> sCand;
854
- std::unordered_set<tableint> sNeigh;
855
- std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
856
- if (listOneHop.size() == 0)
933
+ if (distribution(update_probability_generator_) > updateNeighborProbability)
857
934
  continue;
858
935
 
859
- sCand.insert(internalId);
936
+ sNeigh.insert(elOneHop);
860
937
 
861
- for (auto&& elOneHop : listOneHop) {
862
- sCand.insert(elOneHop);
863
-
864
- if (distribution(update_probability_generator_) > updateNeighborProbability)
865
- continue;
866
-
867
- sNeigh.insert(elOneHop);
868
-
869
- std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
870
- for (auto&& elTwoHop : listTwoHop) {
871
- sCand.insert(elTwoHop);
872
- }
938
+ std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
939
+ for (auto&& elTwoHop : listTwoHop) {
940
+ sCand.insert(elTwoHop);
873
941
  }
942
+ }
874
943
 
875
- for (auto&& neigh : sNeigh) {
876
- // if (neigh == internalId)
877
- // continue;
944
+ for (auto&& neigh : sNeigh) {
945
+ // if (neigh == internalId)
946
+ // continue;
878
947
 
879
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
880
- size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
881
- size_t elementsToKeep = std::min(ef_construction_, size);
882
- for (auto&& cand : sCand) {
883
- if (cand == neigh)
884
- continue;
885
-
886
- dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
887
- if (candidates.size() < elementsToKeep) {
948
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
949
+ size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
950
+ size_t elementsToKeep = std::min(ef_construction_, size);
951
+ for (auto&& cand : sCand) {
952
+ if (cand == neigh)
953
+ continue;
954
+
955
+ dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
956
+ if (candidates.size() < elementsToKeep) {
957
+ candidates.emplace(distance, cand);
958
+ } else {
959
+ if (distance < candidates.top().first) {
960
+ candidates.pop();
888
961
  candidates.emplace(distance, cand);
889
- } else {
890
- if (distance < candidates.top().first) {
891
- candidates.pop();
892
- candidates.emplace(distance, cand);
893
- }
894
962
  }
895
963
  }
964
+ }
896
965
 
897
- // Retrieve neighbours using heuristic and set connections.
898
- getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
899
-
900
- {
901
- std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
902
- linklistsizeint *ll_cur;
903
- ll_cur = get_linklist_at_level(neigh, layer);
904
- size_t candSize = candidates.size();
905
- setListCount(ll_cur, candSize);
906
- tableint *data = (tableint *) (ll_cur + 1);
907
- for (size_t idx = 0; idx < candSize; idx++) {
908
- data[idx] = candidates.top().second;
909
- candidates.pop();
910
- }
966
+ // Retrieve neighbours using heuristic and set connections.
967
+ getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
968
+
969
+ {
970
+ std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
971
+ linklistsizeint *ll_cur;
972
+ ll_cur = get_linklist_at_level(neigh, layer);
973
+ size_t candSize = candidates.size();
974
+ setListCount(ll_cur, candSize);
975
+ tableint *data = (tableint *) (ll_cur + 1);
976
+ for (size_t idx = 0; idx < candSize; idx++) {
977
+ data[idx] = candidates.top().second;
978
+ candidates.pop();
911
979
  }
912
980
  }
913
981
  }
982
+ }
914
983
 
915
- repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
916
- };
984
+ repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
985
+ }
917
986
 
918
- void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
919
- tableint currObj = entryPointInternalId;
920
- if (dataPointLevel < maxLevel) {
921
- dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
922
- for (int level = maxLevel; level > dataPointLevel; level--) {
923
- bool changed = true;
924
- while (changed) {
925
- changed = false;
926
- unsigned int *data;
927
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
928
- data = get_linklist_at_level(currObj,level);
929
- int size = getListCount(data);
930
- tableint *datal = (tableint *) (data + 1);
987
+
988
+ void repairConnectionsForUpdate(
989
+ const void *dataPoint,
990
+ tableint entryPointInternalId,
991
+ tableint dataPointInternalId,
992
+ int dataPointLevel,
993
+ int maxLevel) {
994
+ tableint currObj = entryPointInternalId;
995
+ if (dataPointLevel < maxLevel) {
996
+ dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
997
+ for (int level = maxLevel; level > dataPointLevel; level--) {
998
+ bool changed = true;
999
+ while (changed) {
1000
+ changed = false;
1001
+ unsigned int *data;
1002
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1003
+ data = get_linklist_at_level(currObj, level);
1004
+ int size = getListCount(data);
1005
+ tableint *datal = (tableint *) (data + 1);
931
1006
  #ifdef USE_SSE
932
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
1007
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
933
1008
  #endif
934
- for (int i = 0; i < size; i++) {
1009
+ for (int i = 0; i < size; i++) {
935
1010
  #ifdef USE_SSE
936
- _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
1011
+ _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
937
1012
  #endif
938
- tableint cand = datal[i];
939
- dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
940
- if (d < curdist) {
941
- curdist = d;
942
- currObj = cand;
943
- changed = true;
944
- }
1013
+ tableint cand = datal[i];
1014
+ dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
1015
+ if (d < curdist) {
1016
+ curdist = d;
1017
+ currObj = cand;
1018
+ changed = true;
945
1019
  }
946
1020
  }
947
1021
  }
948
1022
  }
1023
+ }
949
1024
 
950
- if (dataPointLevel > maxLevel)
951
- throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
952
-
953
- for (int level = dataPointLevel; level >= 0; level--) {
954
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
955
- currObj, dataPoint, level);
1025
+ if (dataPointLevel > maxLevel)
1026
+ throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
956
1027
 
957
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
958
- while (topCandidates.size() > 0) {
959
- if (topCandidates.top().second != dataPointInternalId)
960
- filteredTopCandidates.push(topCandidates.top());
1028
+ for (int level = dataPointLevel; level >= 0; level--) {
1029
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
1030
+ currObj, dataPoint, level);
961
1031
 
962
- topCandidates.pop();
963
- }
1032
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
1033
+ while (topCandidates.size() > 0) {
1034
+ if (topCandidates.top().second != dataPointInternalId)
1035
+ filteredTopCandidates.push(topCandidates.top());
964
1036
 
965
- // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
966
- // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
967
- if (filteredTopCandidates.size() > 0) {
968
- bool epDeleted = isMarkedDeleted(entryPointInternalId);
969
- if (epDeleted) {
970
- filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
971
- if (filteredTopCandidates.size() > ef_construction_)
972
- filteredTopCandidates.pop();
973
- }
1037
+ topCandidates.pop();
1038
+ }
974
1039
 
975
- currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
1040
+ // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
1041
+ // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
1042
+ if (filteredTopCandidates.size() > 0) {
1043
+ bool epDeleted = isMarkedDeleted(entryPointInternalId);
1044
+ if (epDeleted) {
1045
+ filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
1046
+ if (filteredTopCandidates.size() > ef_construction_)
1047
+ filteredTopCandidates.pop();
976
1048
  }
1049
+
1050
+ currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
977
1051
  }
978
1052
  }
1053
+ }
1054
+
979
1055
 
980
- std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
981
- std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
982
- unsigned int *data = get_linklist_at_level(internalId, level);
983
- int size = getListCount(data);
984
- std::vector<tableint> result(size);
985
- tableint *ll = (tableint *) (data + 1);
986
- memcpy(result.data(), ll,size * sizeof(tableint));
987
- return result;
988
- };
989
-
990
- tableint addPoint(const void *data_point, labeltype label, int level) {
991
-
992
- tableint cur_c = 0;
993
- {
994
- // Checking if the element with the same label already exists
995
- // if so, updating it *instead* of creating a new element.
996
- std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
997
- auto search = label_lookup_.find(label);
998
- if (search != label_lookup_.end()) {
999
- tableint existingInternalId = search->second;
1000
- templock_curr.unlock();
1001
-
1002
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
1056
+ std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
1057
+ std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
1058
+ unsigned int *data = get_linklist_at_level(internalId, level);
1059
+ int size = getListCount(data);
1060
+ std::vector<tableint> result(size);
1061
+ tableint *ll = (tableint *) (data + 1);
1062
+ memcpy(result.data(), ll, size * sizeof(tableint));
1063
+ return result;
1064
+ }
1003
1065
 
1066
+
1067
+ tableint addPoint(const void *data_point, labeltype label, int level) {
1068
+ tableint cur_c = 0;
1069
+ {
1070
+ // Checking if the element with the same label already exists
1071
+ // if so, updating it *instead* of creating a new element.
1072
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
1073
+ auto search = label_lookup_.find(label);
1074
+ if (search != label_lookup_.end()) {
1075
+ tableint existingInternalId = search->second;
1076
+ if (allow_replace_deleted_) {
1004
1077
  if (isMarkedDeleted(existingInternalId)) {
1005
- unmarkDeletedInternal(existingInternalId);
1078
+ throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
1006
1079
  }
1007
- updatePoint(data_point, existingInternalId, 1.0);
1008
-
1009
- return existingInternalId;
1010
1080
  }
1081
+ lock_table.unlock();
1011
1082
 
1012
- if (cur_element_count >= max_elements_) {
1013
- throw std::runtime_error("The number of elements exceeds the specified limit");
1014
- };
1083
+ if (isMarkedDeleted(existingInternalId)) {
1084
+ unmarkDeletedInternal(existingInternalId);
1085
+ }
1086
+ updatePoint(data_point, existingInternalId, 1.0);
1015
1087
 
1016
- cur_c = cur_element_count;
1017
- cur_element_count++;
1018
- label_lookup_[label] = cur_c;
1088
+ return existingInternalId;
1019
1089
  }
1020
1090
 
1021
- // Take update lock to prevent race conditions on an element with insertion/update at the same time.
1022
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
1023
- std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1024
- int curlevel = getRandomLevel(mult_);
1025
- if (level > 0)
1026
- curlevel = level;
1091
+ if (cur_element_count >= max_elements_) {
1092
+ throw std::runtime_error("The number of elements exceeds the specified limit");
1093
+ }
1027
1094
 
1028
- element_levels_[cur_c] = curlevel;
1095
+ cur_c = cur_element_count;
1096
+ cur_element_count++;
1097
+ label_lookup_[label] = cur_c;
1098
+ }
1029
1099
 
1100
+ std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1101
+ int curlevel = getRandomLevel(mult_);
1102
+ if (level > 0)
1103
+ curlevel = level;
1030
1104
 
1031
- std::unique_lock <std::mutex> templock(global);
1032
- int maxlevelcopy = maxlevel_;
1033
- if (curlevel <= maxlevelcopy)
1034
- templock.unlock();
1035
- tableint currObj = enterpoint_node_;
1036
- tableint enterpoint_copy = enterpoint_node_;
1105
+ element_levels_[cur_c] = curlevel;
1037
1106
 
1107
+ std::unique_lock <std::mutex> templock(global);
1108
+ int maxlevelcopy = maxlevel_;
1109
+ if (curlevel <= maxlevelcopy)
1110
+ templock.unlock();
1111
+ tableint currObj = enterpoint_node_;
1112
+ tableint enterpoint_copy = enterpoint_node_;
1038
1113
 
1039
- memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1114
+ memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1040
1115
 
1041
- // Initialisation of the data and label
1042
- memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1043
- memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1116
+ // Initialisation of the data and label
1117
+ memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1118
+ memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1044
1119
 
1120
+ if (curlevel) {
1121
+ linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1122
+ if (linkLists_[cur_c] == nullptr)
1123
+ throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1124
+ memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1125
+ }
1045
1126
 
1046
- if (curlevel) {
1047
- linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1048
- if (linkLists_[cur_c] == nullptr)
1049
- throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1050
- memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1051
- }
1127
+ if ((signed)currObj != -1) {
1128
+ if (curlevel < maxlevelcopy) {
1129
+ dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1130
+ for (int level = maxlevelcopy; level > curlevel; level--) {
1131
+ bool changed = true;
1132
+ while (changed) {
1133
+ changed = false;
1134
+ unsigned int *data;
1135
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1136
+ data = get_linklist(currObj, level);
1137
+ int size = getListCount(data);
1052
1138
 
1053
- if ((signed)currObj != -1) {
1054
-
1055
- if (curlevel < maxlevelcopy) {
1056
-
1057
- dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1058
- for (int level = maxlevelcopy; level > curlevel; level--) {
1059
-
1060
-
1061
- bool changed = true;
1062
- while (changed) {
1063
- changed = false;
1064
- unsigned int *data;
1065
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1066
- data = get_linklist(currObj,level);
1067
- int size = getListCount(data);
1068
-
1069
- tableint *datal = (tableint *) (data + 1);
1070
- for (int i = 0; i < size; i++) {
1071
- tableint cand = datal[i];
1072
- if (cand < 0 || cand > max_elements_)
1073
- throw std::runtime_error("cand error");
1074
- dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1075
- if (d < curdist) {
1076
- curdist = d;
1077
- currObj = cand;
1078
- changed = true;
1079
- }
1139
+ tableint *datal = (tableint *) (data + 1);
1140
+ for (int i = 0; i < size; i++) {
1141
+ tableint cand = datal[i];
1142
+ if (cand < 0 || cand > max_elements_)
1143
+ throw std::runtime_error("cand error");
1144
+ dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1145
+ if (d < curdist) {
1146
+ curdist = d;
1147
+ currObj = cand;
1148
+ changed = true;
1080
1149
  }
1081
1150
  }
1082
1151
  }
1083
1152
  }
1153
+ }
1084
1154
 
1085
- bool epDeleted = isMarkedDeleted(enterpoint_copy);
1086
- for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1087
- if (level > maxlevelcopy || level < 0) // possible?
1088
- throw std::runtime_error("Level error");
1089
-
1090
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1091
- currObj, data_point, level);
1092
- if (epDeleted) {
1093
- top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1094
- if (top_candidates.size() > ef_construction_)
1095
- top_candidates.pop();
1096
- }
1097
- currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1155
+ bool epDeleted = isMarkedDeleted(enterpoint_copy);
1156
+ for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1157
+ if (level > maxlevelcopy || level < 0) // possible?
1158
+ throw std::runtime_error("Level error");
1159
+
1160
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1161
+ currObj, data_point, level);
1162
+ if (epDeleted) {
1163
+ top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1164
+ if (top_candidates.size() > ef_construction_)
1165
+ top_candidates.pop();
1098
1166
  }
1099
-
1100
-
1101
- } else {
1102
- // Do nothing for the first element
1103
- enterpoint_node_ = 0;
1104
- maxlevel_ = curlevel;
1105
-
1167
+ currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1106
1168
  }
1169
+ } else {
1170
+ // Do nothing for the first element
1171
+ enterpoint_node_ = 0;
1172
+ maxlevel_ = curlevel;
1173
+ }
1107
1174
 
1108
- //Releasing lock for the maximum level
1109
- if (curlevel > maxlevelcopy) {
1110
- enterpoint_node_ = cur_c;
1111
- maxlevel_ = curlevel;
1112
- }
1113
- return cur_c;
1114
- };
1175
+ // Releasing lock for the maximum level
1176
+ if (curlevel > maxlevelcopy) {
1177
+ enterpoint_node_ = cur_c;
1178
+ maxlevel_ = curlevel;
1179
+ }
1180
+ return cur_c;
1181
+ }
1115
1182
 
1116
- std::priority_queue<std::pair<dist_t, labeltype >>
1117
- searchKnn(const void *query_data, size_t k) const {
1118
- std::priority_queue<std::pair<dist_t, labeltype >> result;
1119
- if (cur_element_count == 0) return result;
1120
1183
 
1121
- tableint currObj = enterpoint_node_;
1122
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1184
+ std::priority_queue<std::pair<dist_t, labeltype >>
1185
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
1186
+ std::priority_queue<std::pair<dist_t, labeltype >> result;
1187
+ if (cur_element_count == 0) return result;
1123
1188
 
1124
- for (int level = maxlevel_; level > 0; level--) {
1125
- bool changed = true;
1126
- while (changed) {
1127
- changed = false;
1128
- unsigned int *data;
1189
+ tableint currObj = enterpoint_node_;
1190
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1129
1191
 
1130
- data = (unsigned int *) get_linklist(currObj, level);
1131
- int size = getListCount(data);
1132
- metric_hops++;
1133
- metric_distance_computations+=size;
1192
+ for (int level = maxlevel_; level > 0; level--) {
1193
+ bool changed = true;
1194
+ while (changed) {
1195
+ changed = false;
1196
+ unsigned int *data;
1134
1197
 
1135
- tableint *datal = (tableint *) (data + 1);
1136
- for (int i = 0; i < size; i++) {
1137
- tableint cand = datal[i];
1138
- if (cand < 0 || cand > max_elements_)
1139
- throw std::runtime_error("cand error");
1140
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1198
+ data = (unsigned int *) get_linklist(currObj, level);
1199
+ int size = getListCount(data);
1200
+ metric_hops++;
1201
+ metric_distance_computations+=size;
1141
1202
 
1142
- if (d < curdist) {
1143
- curdist = d;
1144
- currObj = cand;
1145
- changed = true;
1146
- }
1203
+ tableint *datal = (tableint *) (data + 1);
1204
+ for (int i = 0; i < size; i++) {
1205
+ tableint cand = datal[i];
1206
+ if (cand < 0 || cand > max_elements_)
1207
+ throw std::runtime_error("cand error");
1208
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1209
+
1210
+ if (d < curdist) {
1211
+ curdist = d;
1212
+ currObj = cand;
1213
+ changed = true;
1147
1214
  }
1148
1215
  }
1149
1216
  }
1217
+ }
1150
1218
 
1151
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1152
- if (num_deleted_) {
1153
- top_candidates=searchBaseLayerST<true,true>(
1154
- currObj, query_data, std::max(ef_, k));
1155
- }
1156
- else{
1157
- top_candidates=searchBaseLayerST<false,true>(
1158
- currObj, query_data, std::max(ef_, k));
1159
- }
1219
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1220
+ if (num_deleted_) {
1221
+ top_candidates = searchBaseLayerST<true, true>(
1222
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1223
+ } else {
1224
+ top_candidates = searchBaseLayerST<false, true>(
1225
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
1226
+ }
1160
1227
 
1161
- while (top_candidates.size() > k) {
1162
- top_candidates.pop();
1163
- }
1164
- while (top_candidates.size() > 0) {
1165
- std::pair<dist_t, tableint> rez = top_candidates.top();
1166
- result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1167
- top_candidates.pop();
1168
- }
1169
- return result;
1170
- };
1171
-
1172
- void checkIntegrity(){
1173
- int connections_checked=0;
1174
- std::vector <int > inbound_connections_num(cur_element_count,0);
1175
- for(int i = 0;i < cur_element_count; i++){
1176
- for(int l = 0;l <= element_levels_[i]; l++){
1177
- linklistsizeint *ll_cur = get_linklist_at_level(i,l);
1178
- int size = getListCount(ll_cur);
1179
- tableint *data = (tableint *) (ll_cur + 1);
1180
- std::unordered_set<tableint> s;
1181
- for (int j=0; j<size; j++){
1182
- assert(data[j] > 0);
1183
- assert(data[j] < cur_element_count);
1184
- assert (data[j] != i);
1185
- inbound_connections_num[data[j]]++;
1186
- s.insert(data[j]);
1187
- connections_checked++;
1228
+ while (top_candidates.size() > k) {
1229
+ top_candidates.pop();
1230
+ }
1231
+ while (top_candidates.size() > 0) {
1232
+ std::pair<dist_t, tableint> rez = top_candidates.top();
1233
+ result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1234
+ top_candidates.pop();
1235
+ }
1236
+ return result;
1237
+ }
1188
1238
 
1189
- }
1190
- assert(s.size() == size);
1239
+
1240
+ void checkIntegrity() {
1241
+ int connections_checked = 0;
1242
+ std::vector <int > inbound_connections_num(cur_element_count, 0);
1243
+ for (int i = 0; i < cur_element_count; i++) {
1244
+ for (int l = 0; l <= element_levels_[i]; l++) {
1245
+ linklistsizeint *ll_cur = get_linklist_at_level(i, l);
1246
+ int size = getListCount(ll_cur);
1247
+ tableint *data = (tableint *) (ll_cur + 1);
1248
+ std::unordered_set<tableint> s;
1249
+ for (int j = 0; j < size; j++) {
1250
+ assert(data[j] > 0);
1251
+ assert(data[j] < cur_element_count);
1252
+ assert(data[j] != i);
1253
+ inbound_connections_num[data[j]]++;
1254
+ s.insert(data[j]);
1255
+ connections_checked++;
1191
1256
  }
1257
+ assert(s.size() == size);
1192
1258
  }
1193
- if(cur_element_count > 1){
1194
- int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1195
- for(int i=0; i < cur_element_count; i++){
1196
- assert(inbound_connections_num[i] > 0);
1197
- min1=std::min(inbound_connections_num[i],min1);
1198
- max1=std::max(inbound_connections_num[i],max1);
1199
- }
1200
- std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1259
+ }
1260
+ if (cur_element_count > 1) {
1261
+ int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
1262
+ for (int i=0; i < cur_element_count; i++) {
1263
+ assert(inbound_connections_num[i] > 0);
1264
+ min1 = std::min(inbound_connections_num[i], min1);
1265
+ max1 = std::max(inbound_connections_num[i], max1);
1201
1266
  }
1202
- std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1203
-
1267
+ std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1204
1268
  }
1205
-
1206
- };
1207
-
1208
- }
1269
+ std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1270
+ }
1271
+ };
1272
+ } // namespace hnswlib