hnswlib 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,152 @@
1
+ #pragma once
2
+ #include <unordered_map>
3
+ #include <fstream>
4
+ #include <mutex>
5
+ #include <algorithm>
6
+
7
+ namespace hnswlib {
8
+ template<typename dist_t>
9
+ class BruteforceSearch : public AlgorithmInterface<dist_t> {
10
+ public:
11
+ BruteforceSearch(SpaceInterface <dist_t> *s) {
12
+
13
+ }
14
+ BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location) {
15
+ loadIndex(location, s);
16
+ }
17
+
18
+ BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
19
+ maxelements_ = maxElements;
20
+ data_size_ = s->get_data_size();
21
+ fstdistfunc_ = s->get_dist_func();
22
+ dist_func_param_ = s->get_dist_func_param();
23
+ size_per_element_ = data_size_ + sizeof(labeltype);
24
+ data_ = (char *) malloc(maxElements * size_per_element_);
25
+ if (data_ == nullptr)
26
+ std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
27
+ cur_element_count = 0;
28
+ }
29
+
30
+ ~BruteforceSearch() {
31
+ free(data_);
32
+ }
33
+
34
+ char *data_;
35
+ size_t maxelements_;
36
+ size_t cur_element_count;
37
+ size_t size_per_element_;
38
+
39
+ size_t data_size_;
40
+ DISTFUNC <dist_t> fstdistfunc_;
41
+ void *dist_func_param_;
42
+ std::mutex index_lock;
43
+
44
+ std::unordered_map<labeltype,size_t > dict_external_to_internal;
45
+
46
+ void addPoint(const void *datapoint, labeltype label) {
47
+
48
+ int idx;
49
+ {
50
+ std::unique_lock<std::mutex> lock(index_lock);
51
+
52
+
53
+
54
+ auto search=dict_external_to_internal.find(label);
55
+ if (search != dict_external_to_internal.end()) {
56
+ idx=search->second;
57
+ }
58
+ else{
59
+ if (cur_element_count >= maxelements_) {
60
+ throw std::runtime_error("The number of elements exceeds the specified limit\n");
61
+ }
62
+ idx=cur_element_count;
63
+ dict_external_to_internal[label] = idx;
64
+ cur_element_count++;
65
+ }
66
+ }
67
+ memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
68
+ memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
69
+
70
+
71
+
72
+
73
+ };
74
+
75
+ void removePoint(labeltype cur_external) {
76
+ size_t cur_c=dict_external_to_internal[cur_external];
77
+
78
+ dict_external_to_internal.erase(cur_external);
79
+
80
+ labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
81
+ dict_external_to_internal[label]=cur_c;
82
+ memcpy(data_ + size_per_element_ * cur_c,
83
+ data_ + size_per_element_ * (cur_element_count-1),
84
+ data_size_+sizeof(labeltype));
85
+ cur_element_count--;
86
+
87
+ }
88
+
89
+
90
+ std::priority_queue<std::pair<dist_t, labeltype >>
91
+ searchKnn(const void *query_data, size_t k) const {
92
+ std::priority_queue<std::pair<dist_t, labeltype >> topResults;
93
+ if (cur_element_count == 0) return topResults;
94
+ for (int i = 0; i < k; i++) {
95
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
96
+ topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
97
+ data_size_))));
98
+ }
99
+ dist_t lastdist = topResults.top().first;
100
+ for (int i = k; i < cur_element_count; i++) {
101
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
102
+ if (dist <= lastdist) {
103
+ topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
104
+ data_size_))));
105
+ if (topResults.size() > k)
106
+ topResults.pop();
107
+ lastdist = topResults.top().first;
108
+ }
109
+
110
+ }
111
+ return topResults;
112
+ };
113
+
114
+ void saveIndex(const std::string &location) {
115
+ std::ofstream output(location, std::ios::binary);
116
+ std::streampos position;
117
+
118
+ writeBinaryPOD(output, maxelements_);
119
+ writeBinaryPOD(output, size_per_element_);
120
+ writeBinaryPOD(output, cur_element_count);
121
+
122
+ output.write(data_, maxelements_ * size_per_element_);
123
+
124
+ output.close();
125
+ }
126
+
127
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
128
+
129
+
130
+ std::ifstream input(location, std::ios::binary);
131
+ std::streampos position;
132
+
133
+ readBinaryPOD(input, maxelements_);
134
+ readBinaryPOD(input, size_per_element_);
135
+ readBinaryPOD(input, cur_element_count);
136
+
137
+ data_size_ = s->get_data_size();
138
+ fstdistfunc_ = s->get_dist_func();
139
+ dist_func_param_ = s->get_dist_func_param();
140
+ size_per_element_ = data_size_ + sizeof(labeltype);
141
+ data_ = (char *) malloc(maxelements_ * size_per_element_);
142
+ if (data_ == nullptr)
143
+ std::runtime_error("Not enough memory: loadIndex failed to allocate data");
144
+
145
+ input.read(data_, maxelements_ * size_per_element_);
146
+
147
+ input.close();
148
+
149
+ }
150
+
151
+ };
152
+ }
@@ -0,0 +1,1192 @@
1
+ #pragma once
2
+
3
+ #include "visited_list_pool.h"
4
+ #include "hnswlib.h"
5
+ #include <atomic>
6
+ #include <random>
7
+ #include <stdlib.h>
8
+ #include <assert.h>
9
+ #include <unordered_set>
10
+ #include <list>
11
+
12
+ namespace hnswlib {
13
+ typedef unsigned int tableint;
14
+ typedef unsigned int linklistsizeint;
15
+
16
+ template<typename dist_t>
17
+ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
18
+ public:
19
+ static const tableint max_update_element_locks = 65536;
20
+ HierarchicalNSW(SpaceInterface<dist_t> *s) {
21
+
22
+ }
23
+
24
+ HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
25
+ loadIndex(location, s, max_elements);
26
+ }
27
+
28
+ HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
29
+ link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
30
+ max_elements_ = max_elements;
31
+
32
+ has_deletions_=false;
33
+ data_size_ = s->get_data_size();
34
+ fstdistfunc_ = s->get_dist_func();
35
+ dist_func_param_ = s->get_dist_func_param();
36
+ M_ = M;
37
+ maxM_ = M_;
38
+ maxM0_ = M_ * 2;
39
+ ef_construction_ = std::max(ef_construction,M_);
40
+ ef_ = 10;
41
+
42
+ level_generator_.seed(random_seed);
43
+ update_probability_generator_.seed(random_seed + 1);
44
+
45
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
46
+ size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
47
+ offsetData_ = size_links_level0_;
48
+ label_offset_ = size_links_level0_ + data_size_;
49
+ offsetLevel0_ = 0;
50
+
51
+ data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
52
+ if (data_level0_memory_ == nullptr)
53
+ throw std::runtime_error("Not enough memory");
54
+
55
+ cur_element_count = 0;
56
+
57
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
58
+
59
+
60
+
61
+ //initializations for special treatment of the first node
62
+ enterpoint_node_ = -1;
63
+ maxlevel_ = -1;
64
+
65
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
66
+ if (linkLists_ == nullptr)
67
+ throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
68
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
69
+ mult_ = 1 / log(1.0 * M_);
70
+ revSize_ = 1.0 / mult_;
71
+ }
72
+
73
+ struct CompareByFirst {
74
+ constexpr bool operator()(std::pair<dist_t, tableint> const &a,
75
+ std::pair<dist_t, tableint> const &b) const noexcept {
76
+ return a.first < b.first;
77
+ }
78
+ };
79
+
80
+ ~HierarchicalNSW() {
81
+
82
+ free(data_level0_memory_);
83
+ for (tableint i = 0; i < cur_element_count; i++) {
84
+ if (element_levels_[i] > 0)
85
+ free(linkLists_[i]);
86
+ }
87
+ free(linkLists_);
88
+ delete visited_list_pool_;
89
+ }
90
+
91
+ size_t max_elements_;
92
+ size_t cur_element_count;
93
+ size_t size_data_per_element_;
94
+ size_t size_links_per_element_;
95
+
96
+ size_t M_;
97
+ size_t maxM_;
98
+ size_t maxM0_;
99
+ size_t ef_construction_;
100
+
101
+ double mult_, revSize_;
102
+ int maxlevel_;
103
+
104
+
105
+ VisitedListPool *visited_list_pool_;
106
+ std::mutex cur_element_count_guard_;
107
+
108
+ std::vector<std::mutex> link_list_locks_;
109
+
110
+ // Locks to prevent race condition during update/insert of an element at same time.
111
+ // Note: the locks used for additions could also prevent this race condition, provided KNN querying is never run concurrently with updates/inserts (i.e., no multithreaded insert/update/query in parallel).
112
+ std::vector<std::mutex> link_list_update_locks_;
113
+ tableint enterpoint_node_;
114
+
115
+
116
+ size_t size_links_level0_;
117
+ size_t offsetData_, offsetLevel0_;
118
+
119
+
120
+ char *data_level0_memory_;
121
+ char **linkLists_;
122
+ std::vector<int> element_levels_;
123
+
124
+ size_t data_size_;
125
+
126
+ bool has_deletions_;
127
+
128
+
129
+ size_t label_offset_;
130
+ DISTFUNC<dist_t> fstdistfunc_;
131
+ void *dist_func_param_;
132
+ std::unordered_map<labeltype, tableint> label_lookup_;
133
+
134
+ std::default_random_engine level_generator_;
135
+ std::default_random_engine update_probability_generator_;
136
+
137
+ inline labeltype getExternalLabel(tableint internal_id) const {
138
+ labeltype return_label;
139
+ memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
140
+ return return_label;
141
+ }
142
+
143
+ inline void setExternalLabel(tableint internal_id, labeltype label) const {
144
+ memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
145
+ }
146
+
147
+ inline labeltype *getExternalLabeLp(tableint internal_id) const {
148
+ return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
149
+ }
150
+
151
+ inline char *getDataByInternalId(tableint internal_id) const {
152
+ return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
153
+ }
154
+
155
+ int getRandomLevel(double reverse_size) {
156
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
157
+ double r = -log(distribution(level_generator_)) * reverse_size;
158
+ return (int) r;
159
+ }
160
+
161
+
162
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
163
+ searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
164
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
165
+ vl_type *visited_array = vl->mass;
166
+ vl_type visited_array_tag = vl->curV;
167
+
168
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
169
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
170
+
171
+ dist_t lowerBound;
172
+ if (!isMarkedDeleted(ep_id)) {
173
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
174
+ top_candidates.emplace(dist, ep_id);
175
+ lowerBound = dist;
176
+ candidateSet.emplace(-dist, ep_id);
177
+ } else {
178
+ lowerBound = std::numeric_limits<dist_t>::max();
179
+ candidateSet.emplace(-lowerBound, ep_id);
180
+ }
181
+ visited_array[ep_id] = visited_array_tag;
182
+
183
+ while (!candidateSet.empty()) {
184
+ std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
185
+ if ((-curr_el_pair.first) > lowerBound) {
186
+ break;
187
+ }
188
+ candidateSet.pop();
189
+
190
+ tableint curNodeNum = curr_el_pair.second;
191
+
192
+ std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
193
+
194
+ int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
195
+ if (layer == 0) {
196
+ data = (int*)get_linklist0(curNodeNum);
197
+ } else {
198
+ data = (int*)get_linklist(curNodeNum, layer);
199
+ // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
200
+ }
201
+ size_t size = getListCount((linklistsizeint*)data);
202
+ tableint *datal = (tableint *) (data + 1);
203
+ #ifdef USE_SSE
204
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
205
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
206
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
207
+ _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
208
+ #endif
209
+
210
+ for (size_t j = 0; j < size; j++) {
211
+ tableint candidate_id = *(datal + j);
212
+ // if (candidate_id == 0) continue;
213
+ #ifdef USE_SSE
214
+ _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
215
+ _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
216
+ #endif
217
+ if (visited_array[candidate_id] == visited_array_tag) continue;
218
+ visited_array[candidate_id] = visited_array_tag;
219
+ char *currObj1 = (getDataByInternalId(candidate_id));
220
+
221
+ dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
222
+ if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
223
+ candidateSet.emplace(-dist1, candidate_id);
224
+ #ifdef USE_SSE
225
+ _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
226
+ #endif
227
+
228
+ if (!isMarkedDeleted(candidate_id))
229
+ top_candidates.emplace(dist1, candidate_id);
230
+
231
+ if (top_candidates.size() > ef_construction_)
232
+ top_candidates.pop();
233
+
234
+ if (!top_candidates.empty())
235
+ lowerBound = top_candidates.top().first;
236
+ }
237
+ }
238
+ }
239
+ visited_list_pool_->releaseVisitedList(vl);
240
+
241
+ return top_candidates;
242
+ }
243
+
244
+ mutable std::atomic<long> metric_distance_computations;
245
+ mutable std::atomic<long> metric_hops;
246
+
247
+ template <bool has_deletions, bool collect_metrics=false>
248
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
249
+ searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
250
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
251
+ vl_type *visited_array = vl->mass;
252
+ vl_type visited_array_tag = vl->curV;
253
+
254
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
255
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
256
+
257
+ dist_t lowerBound;
258
+ if (!has_deletions || !isMarkedDeleted(ep_id)) {
259
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
260
+ lowerBound = dist;
261
+ top_candidates.emplace(dist, ep_id);
262
+ candidate_set.emplace(-dist, ep_id);
263
+ } else {
264
+ lowerBound = std::numeric_limits<dist_t>::max();
265
+ candidate_set.emplace(-lowerBound, ep_id);
266
+ }
267
+
268
+ visited_array[ep_id] = visited_array_tag;
269
+
270
+ while (!candidate_set.empty()) {
271
+
272
+ std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
273
+
274
+ if ((-current_node_pair.first) > lowerBound) {
275
+ break;
276
+ }
277
+ candidate_set.pop();
278
+
279
+ tableint current_node_id = current_node_pair.second;
280
+ int *data = (int *) get_linklist0(current_node_id);
281
+ size_t size = getListCount((linklistsizeint*)data);
282
+ // bool cur_node_deleted = isMarkedDeleted(current_node_id);
283
+ if(collect_metrics){
284
+ metric_hops++;
285
+ metric_distance_computations+=size;
286
+ }
287
+
288
+ #ifdef USE_SSE
289
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
290
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
291
+ _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
292
+ _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
293
+ #endif
294
+
295
+ for (size_t j = 1; j <= size; j++) {
296
+ int candidate_id = *(data + j);
297
+ // if (candidate_id == 0) continue;
298
+ #ifdef USE_SSE
299
+ _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
300
+ _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
301
+ _MM_HINT_T0);////////////
302
+ #endif
303
+ if (!(visited_array[candidate_id] == visited_array_tag)) {
304
+
305
+ visited_array[candidate_id] = visited_array_tag;
306
+
307
+ char *currObj1 = (getDataByInternalId(candidate_id));
308
+ dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
309
+
310
+ if (top_candidates.size() < ef || lowerBound > dist) {
311
+ candidate_set.emplace(-dist, candidate_id);
312
+ #ifdef USE_SSE
313
+ _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
314
+ offsetLevel0_,///////////
315
+ _MM_HINT_T0);////////////////////////
316
+ #endif
317
+
318
+ if (!has_deletions || !isMarkedDeleted(candidate_id))
319
+ top_candidates.emplace(dist, candidate_id);
320
+
321
+ if (top_candidates.size() > ef)
322
+ top_candidates.pop();
323
+
324
+ if (!top_candidates.empty())
325
+ lowerBound = top_candidates.top().first;
326
+ }
327
+ }
328
+ }
329
+ }
330
+
331
+ visited_list_pool_->releaseVisitedList(vl);
332
+ return top_candidates;
333
+ }
334
+
335
+ void getNeighborsByHeuristic2(
336
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
337
+ const size_t M) {
338
+ if (top_candidates.size() < M) {
339
+ return;
340
+ }
341
+
342
+ std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
343
+ std::vector<std::pair<dist_t, tableint>> return_list;
344
+ while (top_candidates.size() > 0) {
345
+ queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
346
+ top_candidates.pop();
347
+ }
348
+
349
+ while (queue_closest.size()) {
350
+ if (return_list.size() >= M)
351
+ break;
352
+ std::pair<dist_t, tableint> curent_pair = queue_closest.top();
353
+ dist_t dist_to_query = -curent_pair.first;
354
+ queue_closest.pop();
355
+ bool good = true;
356
+
357
+ for (std::pair<dist_t, tableint> second_pair : return_list) {
358
+ dist_t curdist =
359
+ fstdistfunc_(getDataByInternalId(second_pair.second),
360
+ getDataByInternalId(curent_pair.second),
361
+ dist_func_param_);;
362
+ if (curdist < dist_to_query) {
363
+ good = false;
364
+ break;
365
+ }
366
+ }
367
+ if (good) {
368
+ return_list.push_back(curent_pair);
369
+ }
370
+ }
371
+
372
+ for (std::pair<dist_t, tableint> curent_pair : return_list) {
373
+ top_candidates.emplace(-curent_pair.first, curent_pair.second);
374
+ }
375
+ }
376
+
377
+
378
+ linklistsizeint *get_linklist0(tableint internal_id) const {
379
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
380
+ };
381
+
382
+ linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
383
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
384
+ };
385
+
386
+ linklistsizeint *get_linklist(tableint internal_id, int level) const {
387
+ return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
388
+ };
389
+
390
+ linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
391
+ return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
392
+ };
393
+
394
+ tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
395
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
396
+ int level, bool isUpdate) {
397
+ size_t Mcurmax = level ? maxM_ : maxM0_;
398
+ getNeighborsByHeuristic2(top_candidates, M_);
399
+ if (top_candidates.size() > M_)
400
+ throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
401
+
402
+ std::vector<tableint> selectedNeighbors;
403
+ selectedNeighbors.reserve(M_);
404
+ while (top_candidates.size() > 0) {
405
+ selectedNeighbors.push_back(top_candidates.top().second);
406
+ top_candidates.pop();
407
+ }
408
+
409
+ tableint next_closest_entry_point = selectedNeighbors.back();
410
+
411
+ {
412
+ linklistsizeint *ll_cur;
413
+ if (level == 0)
414
+ ll_cur = get_linklist0(cur_c);
415
+ else
416
+ ll_cur = get_linklist(cur_c, level);
417
+
418
+ if (*ll_cur && !isUpdate) {
419
+ throw std::runtime_error("The newly inserted element should have blank link list");
420
+ }
421
+ setListCount(ll_cur,selectedNeighbors.size());
422
+ tableint *data = (tableint *) (ll_cur + 1);
423
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
424
+ if (data[idx] && !isUpdate)
425
+ throw std::runtime_error("Possible memory corruption");
426
+ if (level > element_levels_[selectedNeighbors[idx]])
427
+ throw std::runtime_error("Trying to make a link on a non-existent level");
428
+
429
+ data[idx] = selectedNeighbors[idx];
430
+
431
+ }
432
+ }
433
+
434
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
435
+
436
+ std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
437
+
438
+ linklistsizeint *ll_other;
439
+ if (level == 0)
440
+ ll_other = get_linklist0(selectedNeighbors[idx]);
441
+ else
442
+ ll_other = get_linklist(selectedNeighbors[idx], level);
443
+
444
+ size_t sz_link_list_other = getListCount(ll_other);
445
+
446
+ if (sz_link_list_other > Mcurmax)
447
+ throw std::runtime_error("Bad value of sz_link_list_other");
448
+ if (selectedNeighbors[idx] == cur_c)
449
+ throw std::runtime_error("Trying to connect an element to itself");
450
+ if (level > element_levels_[selectedNeighbors[idx]])
451
+ throw std::runtime_error("Trying to make a link on a non-existent level");
452
+
453
+ tableint *data = (tableint *) (ll_other + 1);
454
+
455
+ bool is_cur_c_present = false;
456
+ if (isUpdate) {
457
+ for (size_t j = 0; j < sz_link_list_other; j++) {
458
+ if (data[j] == cur_c) {
459
+ is_cur_c_present = true;
460
+ break;
461
+ }
462
+ }
463
+ }
464
+
465
+ // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
466
+ if (!is_cur_c_present) {
467
+ if (sz_link_list_other < Mcurmax) {
468
+ data[sz_link_list_other] = cur_c;
469
+ setListCount(ll_other, sz_link_list_other + 1);
470
+ } else {
471
+ // finding the "weakest" element to replace it with the new one
472
+ dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
473
+ dist_func_param_);
474
+ // Heuristic:
475
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
476
+ candidates.emplace(d_max, cur_c);
477
+
478
+ for (size_t j = 0; j < sz_link_list_other; j++) {
479
+ candidates.emplace(
480
+ fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
481
+ dist_func_param_), data[j]);
482
+ }
483
+
484
+ getNeighborsByHeuristic2(candidates, Mcurmax);
485
+
486
+ int indx = 0;
487
+ while (candidates.size() > 0) {
488
+ data[indx] = candidates.top().second;
489
+ candidates.pop();
490
+ indx++;
491
+ }
492
+
493
+ setListCount(ll_other, indx);
494
+ // Nearest K:
495
+ /*int indx = -1;
496
+ for (int j = 0; j < sz_link_list_other; j++) {
497
+ dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
498
+ if (d > d_max) {
499
+ indx = j;
500
+ d_max = d;
501
+ }
502
+ }
503
+ if (indx >= 0) {
504
+ data[indx] = cur_c;
505
+ } */
506
+ }
507
+ }
508
+ }
509
+
510
+ return next_closest_entry_point;
511
+ }
512
+
513
+ std::mutex global;
514
+ size_t ef_;
515
+
516
+ void setEf(size_t ef) {
517
+ ef_ = ef;
518
+ }
519
+
520
+
521
+ std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
522
+ std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
523
+ if (cur_element_count == 0) return top_candidates;
524
+ tableint currObj = enterpoint_node_;
525
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
526
+
527
+ for (size_t level = maxlevel_; level > 0; level--) {
528
+ bool changed = true;
529
+ while (changed) {
530
+ changed = false;
531
+ int *data;
532
+ data = (int *) get_linklist(currObj,level);
533
+ int size = getListCount(data);
534
+ tableint *datal = (tableint *) (data + 1);
535
+ for (int i = 0; i < size; i++) {
536
+ tableint cand = datal[i];
537
+ if (cand < 0 || cand > max_elements_)
538
+ throw std::runtime_error("cand error");
539
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
540
+
541
+ if (d < curdist) {
542
+ curdist = d;
543
+ currObj = cand;
544
+ changed = true;
545
+ }
546
+ }
547
+ }
548
+ }
549
+
550
+ if (has_deletions_) {
551
+ std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
552
+ ef_);
553
+ top_candidates.swap(top_candidates1);
554
+ }
555
+ else{
556
+ std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
557
+ ef_);
558
+ top_candidates.swap(top_candidates1);
559
+ }
560
+
561
+ while (top_candidates.size() > k) {
562
+ top_candidates.pop();
563
+ }
564
+ return top_candidates;
565
+ };
566
+
567
+ void resizeIndex(size_t new_max_elements){
568
+ if (new_max_elements<cur_element_count)
569
+ throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
570
+
571
+
572
+ delete visited_list_pool_;
573
+ visited_list_pool_ = new VisitedListPool(1, new_max_elements);
574
+
575
+
576
+ element_levels_.resize(new_max_elements);
577
+
578
+ std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
579
+
580
+ // Reallocate base layer
581
+ char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
582
+ if (data_level0_memory_new == nullptr)
583
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
584
+ data_level0_memory_ = data_level0_memory_new;
585
+
586
+ // Reallocate all other layers
587
+ char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
588
+ if (linkLists_new == nullptr)
589
+ throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
590
+ linkLists_ = linkLists_new;
591
+
592
+ max_elements_ = new_max_elements;
593
+ }
594
+
595
+ void saveIndex(const std::string &location) {
596
+ std::ofstream output(location, std::ios::binary);
597
+ std::streampos position;
598
+
599
+ writeBinaryPOD(output, offsetLevel0_);
600
+ writeBinaryPOD(output, max_elements_);
601
+ writeBinaryPOD(output, cur_element_count);
602
+ writeBinaryPOD(output, size_data_per_element_);
603
+ writeBinaryPOD(output, label_offset_);
604
+ writeBinaryPOD(output, offsetData_);
605
+ writeBinaryPOD(output, maxlevel_);
606
+ writeBinaryPOD(output, enterpoint_node_);
607
+ writeBinaryPOD(output, maxM_);
608
+
609
+ writeBinaryPOD(output, maxM0_);
610
+ writeBinaryPOD(output, M_);
611
+ writeBinaryPOD(output, mult_);
612
+ writeBinaryPOD(output, ef_construction_);
613
+
614
+ output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
615
+
616
+ for (size_t i = 0; i < cur_element_count; i++) {
617
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
618
+ writeBinaryPOD(output, linkListSize);
619
+ if (linkListSize)
620
+ output.write(linkLists_[i], linkListSize);
621
+ }
622
+ output.close();
623
+ }
624
+
625
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
626
+
627
+
628
+ std::ifstream input(location, std::ios::binary);
629
+
630
+ if (!input.is_open())
631
+ throw std::runtime_error("Cannot open file");
632
+
633
+ // get file size:
634
+ input.seekg(0,input.end);
635
+ std::streampos total_filesize=input.tellg();
636
+ input.seekg(0,input.beg);
637
+
638
+ readBinaryPOD(input, offsetLevel0_);
639
+ readBinaryPOD(input, max_elements_);
640
+ readBinaryPOD(input, cur_element_count);
641
+
642
+ size_t max_elements=max_elements_i;
643
+ if(max_elements < cur_element_count)
644
+ max_elements = max_elements_;
645
+ max_elements_ = max_elements;
646
+ readBinaryPOD(input, size_data_per_element_);
647
+ readBinaryPOD(input, label_offset_);
648
+ readBinaryPOD(input, offsetData_);
649
+ readBinaryPOD(input, maxlevel_);
650
+ readBinaryPOD(input, enterpoint_node_);
651
+
652
+ readBinaryPOD(input, maxM_);
653
+ readBinaryPOD(input, maxM0_);
654
+ readBinaryPOD(input, M_);
655
+ readBinaryPOD(input, mult_);
656
+ readBinaryPOD(input, ef_construction_);
657
+
658
+
659
+ data_size_ = s->get_data_size();
660
+ fstdistfunc_ = s->get_dist_func();
661
+ dist_func_param_ = s->get_dist_func_param();
662
+
663
+ auto pos=input.tellg();
664
+
665
+
666
+ /// Optional - check if index is ok:
667
+
668
+ input.seekg(cur_element_count * size_data_per_element_,input.cur);
669
+ for (size_t i = 0; i < cur_element_count; i++) {
670
+ if(input.tellg() < 0 || input.tellg()>=total_filesize){
671
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
672
+ }
673
+
674
+ unsigned int linkListSize;
675
+ readBinaryPOD(input, linkListSize);
676
+ if (linkListSize != 0) {
677
+ input.seekg(linkListSize,input.cur);
678
+ }
679
+ }
680
+
681
+ // throw exception if it either corrupted or old index
682
+ if(input.tellg()!=total_filesize)
683
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
684
+
685
+ input.clear();
686
+
687
+ /// Optional check end
688
+
689
+ input.seekg(pos,input.beg);
690
+
691
+
692
+ data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
693
+ if (data_level0_memory_ == nullptr)
694
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
695
+ input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
696
+
697
+
698
+
699
+
700
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
701
+
702
+
703
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
704
+ std::vector<std::mutex>(max_elements).swap(link_list_locks_);
705
+ std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
706
+
707
+
708
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
709
+
710
+
711
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
712
+ if (linkLists_ == nullptr)
713
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
714
+ element_levels_ = std::vector<int>(max_elements);
715
+ revSize_ = 1.0 / mult_;
716
+ ef_ = 10;
717
+ for (size_t i = 0; i < cur_element_count; i++) {
718
+ label_lookup_[getExternalLabel(i)]=i;
719
+ unsigned int linkListSize;
720
+ readBinaryPOD(input, linkListSize);
721
+ if (linkListSize == 0) {
722
+ element_levels_[i] = 0;
723
+
724
+ linkLists_[i] = nullptr;
725
+ } else {
726
+ element_levels_[i] = linkListSize / size_links_per_element_;
727
+ linkLists_[i] = (char *) malloc(linkListSize);
728
+ if (linkLists_[i] == nullptr)
729
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
730
+ input.read(linkLists_[i], linkListSize);
731
+ }
732
+ }
733
+
734
+ has_deletions_=false;
735
+
736
+ for (size_t i = 0; i < cur_element_count; i++) {
737
+ if(isMarkedDeleted(i))
738
+ has_deletions_=true;
739
+ }
740
+
741
+ input.close();
742
+
743
+ return;
744
+ }
745
+
746
+ template<typename data_t>
747
+ std::vector<data_t> getDataByLabel(labeltype label)
748
+ {
749
+ tableint label_c;
750
+ auto search = label_lookup_.find(label);
751
+ if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
752
+ throw std::runtime_error("Label not found");
753
+ }
754
+ label_c = search->second;
755
+
756
+ char* data_ptrv = getDataByInternalId(label_c);
757
+ size_t dim = *((size_t *) dist_func_param_);
758
+ std::vector<data_t> data;
759
+ data_t* data_ptr = (data_t*) data_ptrv;
760
+ for (int i = 0; i < dim; i++) {
761
+ data.push_back(*data_ptr);
762
+ data_ptr += 1;
763
+ }
764
+ return data;
765
+ }
766
+
767
+ static const unsigned char DELETE_MARK = 0x01;
768
+ // static const unsigned char REUSE_MARK = 0x10;
769
+ /**
770
+ * Marks an element with the given label deleted, does NOT really change the current graph.
771
+ * @param label
772
+ */
773
+ void markDelete(labeltype label)
774
+ {
775
+ has_deletions_=true;
776
+ auto search = label_lookup_.find(label);
777
+ if (search == label_lookup_.end()) {
778
+ throw std::runtime_error("Label not found");
779
+ }
780
+ markDeletedInternal(search->second);
781
+ }
782
+
783
+ /**
784
+ * Uses the first 8 bits of the memory for the linked list to store the mark,
785
+ * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
786
+ * @param internalId
787
+ */
788
+ void markDeletedInternal(tableint internalId) {
789
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
790
+ *ll_cur |= DELETE_MARK;
791
+ }
792
+
793
+ /**
794
+ * Remove the deleted mark of the node.
795
+ * @param internalId
796
+ */
797
+ void unmarkDeletedInternal(tableint internalId) {
798
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
799
+ *ll_cur &= ~DELETE_MARK;
800
+ }
801
+
802
+ /**
803
+ * Checks the first 8 bits of the memory to see if the element is marked deleted.
804
+ * @param internalId
805
+ * @return
806
+ */
807
+ bool isMarkedDeleted(tableint internalId) const {
808
+ unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
809
+ return *ll_cur & DELETE_MARK;
810
+ }
811
+
812
+ unsigned short int getListCount(linklistsizeint * ptr) const {
813
+ return *((unsigned short int *)ptr);
814
+ }
815
+
816
+ void setListCount(linklistsizeint * ptr, unsigned short int size) const {
817
+ *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
818
+ }
819
+
820
+ void addPoint(const void *data_point, labeltype label) {
821
+ addPoint(data_point, label,-1);
822
+ }
823
+
824
+ void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
825
+ // update the feature vector associated with existing point with new vector
826
+ memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
827
+
828
+ int maxLevelCopy = maxlevel_;
829
+ tableint entryPointCopy = enterpoint_node_;
830
+ // If point to be updated is entry point and graph just contains single element then just return.
831
+ if (entryPointCopy == internalId && cur_element_count == 1)
832
+ return;
833
+
834
+ int elemLevel = element_levels_[internalId];
835
+ std::uniform_real_distribution<float> distribution(0.0, 1.0);
836
+ for (int layer = 0; layer <= elemLevel; layer++) {
837
+ std::unordered_set<tableint> sCand;
838
+ std::unordered_set<tableint> sNeigh;
839
+ std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
840
+ if (listOneHop.size() == 0)
841
+ continue;
842
+
843
+ sCand.insert(internalId);
844
+
845
+ for (auto&& elOneHop : listOneHop) {
846
+ sCand.insert(elOneHop);
847
+
848
+ if (distribution(update_probability_generator_) > updateNeighborProbability)
849
+ continue;
850
+
851
+ sNeigh.insert(elOneHop);
852
+
853
+ std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
854
+ for (auto&& elTwoHop : listTwoHop) {
855
+ sCand.insert(elTwoHop);
856
+ }
857
+ }
858
+
859
+ for (auto&& neigh : sNeigh) {
860
+ // if (neigh == internalId)
861
+ // continue;
862
+
863
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
864
+ size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
865
+ size_t elementsToKeep = std::min(ef_construction_, size);
866
+ for (auto&& cand : sCand) {
867
+ if (cand == neigh)
868
+ continue;
869
+
870
+ dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
871
+ if (candidates.size() < elementsToKeep) {
872
+ candidates.emplace(distance, cand);
873
+ } else {
874
+ if (distance < candidates.top().first) {
875
+ candidates.pop();
876
+ candidates.emplace(distance, cand);
877
+ }
878
+ }
879
+ }
880
+
881
+ // Retrieve neighbours using heuristic and set connections.
882
+ getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
883
+
884
+ {
885
+ std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
886
+ linklistsizeint *ll_cur;
887
+ ll_cur = get_linklist_at_level(neigh, layer);
888
+ size_t candSize = candidates.size();
889
+ setListCount(ll_cur, candSize);
890
+ tableint *data = (tableint *) (ll_cur + 1);
891
+ for (size_t idx = 0; idx < candSize; idx++) {
892
+ data[idx] = candidates.top().second;
893
+ candidates.pop();
894
+ }
895
+ }
896
+ }
897
+ }
898
+
899
+ repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
900
+ };
901
+
902
+ void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
903
+ tableint currObj = entryPointInternalId;
904
+ if (dataPointLevel < maxLevel) {
905
+ dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
906
+ for (int level = maxLevel; level > dataPointLevel; level--) {
907
+ bool changed = true;
908
+ while (changed) {
909
+ changed = false;
910
+ unsigned int *data;
911
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
912
+ data = get_linklist_at_level(currObj,level);
913
+ int size = getListCount(data);
914
+ tableint *datal = (tableint *) (data + 1);
915
+ #ifdef USE_SSE
916
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
917
+ #endif
918
+ for (int i = 0; i < size; i++) {
919
+ #ifdef USE_SSE
920
+ _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
921
+ #endif
922
+ tableint cand = datal[i];
923
+ dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
924
+ if (d < curdist) {
925
+ curdist = d;
926
+ currObj = cand;
927
+ changed = true;
928
+ }
929
+ }
930
+ }
931
+ }
932
+ }
933
+
934
+ if (dataPointLevel > maxLevel)
935
+ throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
936
+
937
+ for (int level = dataPointLevel; level >= 0; level--) {
938
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
939
+ currObj, dataPoint, level);
940
+
941
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
942
+ while (topCandidates.size() > 0) {
943
+ if (topCandidates.top().second != dataPointInternalId)
944
+ filteredTopCandidates.push(topCandidates.top());
945
+
946
+ topCandidates.pop();
947
+ }
948
+
949
+ // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
950
+ // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
951
+ if (filteredTopCandidates.size() > 0) {
952
+ bool epDeleted = isMarkedDeleted(entryPointInternalId);
953
+ if (epDeleted) {
954
+ filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
955
+ if (filteredTopCandidates.size() > ef_construction_)
956
+ filteredTopCandidates.pop();
957
+ }
958
+
959
+ currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
960
+ }
961
+ }
962
+ }
963
+
964
+ std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
965
+ std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
966
+ unsigned int *data = get_linklist_at_level(internalId, level);
967
+ int size = getListCount(data);
968
+ std::vector<tableint> result(size);
969
+ tableint *ll = (tableint *) (data + 1);
970
+ memcpy(result.data(), ll,size * sizeof(tableint));
971
+ return result;
972
+ };
973
+
974
+ tableint addPoint(const void *data_point, labeltype label, int level) {
975
+
976
+ tableint cur_c = 0;
977
+ {
978
+ // Checking if the element with the same label already exists
979
+ // if so, updating it *instead* of creating a new element.
980
+ std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
981
+ auto search = label_lookup_.find(label);
982
+ if (search != label_lookup_.end()) {
983
+ tableint existingInternalId = search->second;
984
+ templock_curr.unlock();
985
+
986
+ std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
987
+
988
+ if (isMarkedDeleted(existingInternalId)) {
989
+ unmarkDeletedInternal(existingInternalId);
990
+ }
991
+ updatePoint(data_point, existingInternalId, 1.0);
992
+
993
+ return existingInternalId;
994
+ }
995
+
996
+ if (cur_element_count >= max_elements_) {
997
+ throw std::runtime_error("The number of elements exceeds the specified limit");
998
+ };
999
+
1000
+ cur_c = cur_element_count;
1001
+ cur_element_count++;
1002
+ label_lookup_[label] = cur_c;
1003
+ }
1004
+
1005
+ // Take update lock to prevent race conditions on an element with insertion/update at the same time.
1006
+ std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
1007
+ std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
1008
+ int curlevel = getRandomLevel(mult_);
1009
+ if (level > 0)
1010
+ curlevel = level;
1011
+
1012
+ element_levels_[cur_c] = curlevel;
1013
+
1014
+
1015
+ std::unique_lock <std::mutex> templock(global);
1016
+ int maxlevelcopy = maxlevel_;
1017
+ if (curlevel <= maxlevelcopy)
1018
+ templock.unlock();
1019
+ tableint currObj = enterpoint_node_;
1020
+ tableint enterpoint_copy = enterpoint_node_;
1021
+
1022
+
1023
+ memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1024
+
1025
+ // Initialisation of the data and label
1026
+ memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1027
+ memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1028
+
1029
+
1030
+ if (curlevel) {
1031
+ linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1032
+ if (linkLists_[cur_c] == nullptr)
1033
+ throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1034
+ memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1035
+ }
1036
+
1037
+ if ((signed)currObj != -1) {
1038
+
1039
+ if (curlevel < maxlevelcopy) {
1040
+
1041
+ dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1042
+ for (int level = maxlevelcopy; level > curlevel; level--) {
1043
+
1044
+
1045
+ bool changed = true;
1046
+ while (changed) {
1047
+ changed = false;
1048
+ unsigned int *data;
1049
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1050
+ data = get_linklist(currObj,level);
1051
+ int size = getListCount(data);
1052
+
1053
+ tableint *datal = (tableint *) (data + 1);
1054
+ for (int i = 0; i < size; i++) {
1055
+ tableint cand = datal[i];
1056
+ if (cand < 0 || cand > max_elements_)
1057
+ throw std::runtime_error("cand error");
1058
+ dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1059
+ if (d < curdist) {
1060
+ curdist = d;
1061
+ currObj = cand;
1062
+ changed = true;
1063
+ }
1064
+ }
1065
+ }
1066
+ }
1067
+ }
1068
+
1069
+ bool epDeleted = isMarkedDeleted(enterpoint_copy);
1070
+ for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
1071
+ if (level > maxlevelcopy || level < 0) // possible?
1072
+ throw std::runtime_error("Level error");
1073
+
1074
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
1075
+ currObj, data_point, level);
1076
+ if (epDeleted) {
1077
+ top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1078
+ if (top_candidates.size() > ef_construction_)
1079
+ top_candidates.pop();
1080
+ }
1081
+ currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
1082
+ }
1083
+
1084
+
1085
+ } else {
1086
+ // Do nothing for the first element
1087
+ enterpoint_node_ = 0;
1088
+ maxlevel_ = curlevel;
1089
+
1090
+ }
1091
+
1092
+ //Releasing lock for the maximum level
1093
+ if (curlevel > maxlevelcopy) {
1094
+ enterpoint_node_ = cur_c;
1095
+ maxlevel_ = curlevel;
1096
+ }
1097
+ return cur_c;
1098
+ };
1099
+
1100
+ std::priority_queue<std::pair<dist_t, labeltype >>
1101
+ searchKnn(const void *query_data, size_t k) const {
1102
+ std::priority_queue<std::pair<dist_t, labeltype >> result;
1103
+ if (cur_element_count == 0) return result;
1104
+
1105
+ tableint currObj = enterpoint_node_;
1106
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1107
+
1108
+ for (int level = maxlevel_; level > 0; level--) {
1109
+ bool changed = true;
1110
+ while (changed) {
1111
+ changed = false;
1112
+ unsigned int *data;
1113
+
1114
+ data = (unsigned int *) get_linklist(currObj, level);
1115
+ int size = getListCount(data);
1116
+ metric_hops++;
1117
+ metric_distance_computations+=size;
1118
+
1119
+ tableint *datal = (tableint *) (data + 1);
1120
+ for (int i = 0; i < size; i++) {
1121
+ tableint cand = datal[i];
1122
+ if (cand < 0 || cand > max_elements_)
1123
+ throw std::runtime_error("cand error");
1124
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1125
+
1126
+ if (d < curdist) {
1127
+ curdist = d;
1128
+ currObj = cand;
1129
+ changed = true;
1130
+ }
1131
+ }
1132
+ }
1133
+ }
1134
+
1135
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
1136
+ if (has_deletions_) {
1137
+ top_candidates=searchBaseLayerST<true,true>(
1138
+ currObj, query_data, std::max(ef_, k));
1139
+ }
1140
+ else{
1141
+ top_candidates=searchBaseLayerST<false,true>(
1142
+ currObj, query_data, std::max(ef_, k));
1143
+ }
1144
+
1145
+ while (top_candidates.size() > k) {
1146
+ top_candidates.pop();
1147
+ }
1148
+ while (top_candidates.size() > 0) {
1149
+ std::pair<dist_t, tableint> rez = top_candidates.top();
1150
+ result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
1151
+ top_candidates.pop();
1152
+ }
1153
+ return result;
1154
+ };
1155
+
1156
+ void checkIntegrity(){
1157
+ int connections_checked=0;
1158
+ std::vector <int > inbound_connections_num(cur_element_count,0);
1159
+ for(int i = 0;i < cur_element_count; i++){
1160
+ for(int l = 0;l <= element_levels_[i]; l++){
1161
+ linklistsizeint *ll_cur = get_linklist_at_level(i,l);
1162
+ int size = getListCount(ll_cur);
1163
+ tableint *data = (tableint *) (ll_cur + 1);
1164
+ std::unordered_set<tableint> s;
1165
+ for (int j=0; j<size; j++){
1166
+ assert(data[j] > 0);
1167
+ assert(data[j] < cur_element_count);
1168
+ assert (data[j] != i);
1169
+ inbound_connections_num[data[j]]++;
1170
+ s.insert(data[j]);
1171
+ connections_checked++;
1172
+
1173
+ }
1174
+ assert(s.size() == size);
1175
+ }
1176
+ }
1177
+ if(cur_element_count > 1){
1178
+ int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1179
+ for(int i=0; i < cur_element_count; i++){
1180
+ assert(inbound_connections_num[i] > 0);
1181
+ min1=std::min(inbound_connections_num[i],min1);
1182
+ max1=std::max(inbound_connections_num[i],max1);
1183
+ }
1184
+ std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1185
+ }
1186
+ std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1187
+
1188
+ }
1189
+
1190
+ };
1191
+
1192
+ }