hnswlib 0.6.2 → 0.8.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -3,155 +3,166 @@
3
3
  #include <fstream>
4
4
  #include <mutex>
5
5
  #include <algorithm>
6
+ #include <assert.h>
6
7
 
7
8
  namespace hnswlib {
8
- template<typename dist_t>
9
- class BruteforceSearch : public AlgorithmInterface<dist_t> {
10
- public:
11
- BruteforceSearch() : data_(nullptr) { }
12
- BruteforceSearch(SpaceInterface <dist_t> *s) {
13
-
14
- }
15
- BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location) {
16
- loadIndex(location, s);
17
- }
18
-
19
- BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
20
- maxelements_ = maxElements;
21
- data_size_ = s->get_data_size();
22
- fstdistfunc_ = s->get_dist_func();
23
- dist_func_param_ = s->get_dist_func_param();
24
- size_per_element_ = data_size_ + sizeof(labeltype);
25
- data_ = (char *) malloc(maxElements * size_per_element_);
26
- if (data_ == nullptr)
27
- std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
28
- cur_element_count = 0;
29
- }
30
-
31
- ~BruteforceSearch() {
32
- if (data_) free(data_);
33
- }
34
-
35
- char *data_;
36
- size_t maxelements_;
37
- size_t cur_element_count;
38
- size_t size_per_element_;
39
-
40
- size_t data_size_;
41
- DISTFUNC <dist_t> fstdistfunc_;
42
- void *dist_func_param_;
43
- std::mutex index_lock;
44
-
45
- std::unordered_map<labeltype,size_t > dict_external_to_internal;
46
-
47
- void addPoint(const void *datapoint, labeltype label) {
48
-
49
- int idx;
50
- {
51
- std::unique_lock<std::mutex> lock(index_lock);
52
-
53
-
54
-
55
- auto search=dict_external_to_internal.find(label);
56
- if (search != dict_external_to_internal.end()) {
57
- idx=search->second;
58
- }
59
- else{
60
- if (cur_element_count >= maxelements_) {
61
- throw std::runtime_error("The number of elements exceeds the specified limit\n");
62
- }
63
- idx=cur_element_count;
64
- dict_external_to_internal[label] = idx;
65
- cur_element_count++;
9
+ template<typename dist_t>
10
+ class BruteforceSearch : public AlgorithmInterface<dist_t> {
11
+ public:
12
+ char *data_;
13
+ size_t maxelements_;
14
+ size_t cur_element_count;
15
+ size_t size_per_element_;
16
+
17
+ size_t data_size_;
18
+ DISTFUNC <dist_t> fstdistfunc_;
19
+ void *dist_func_param_;
20
+ std::mutex index_lock;
21
+
22
+ std::unordered_map<labeltype, size_t > dict_external_to_internal;
23
+
24
+ BruteforceSearch() : data_(nullptr) { }
25
+
26
+ BruteforceSearch(SpaceInterface <dist_t> *s)
27
+ : data_(nullptr),
28
+ maxelements_(0),
29
+ cur_element_count(0),
30
+ size_per_element_(0),
31
+ data_size_(0),
32
+ dist_func_param_(nullptr) {
33
+ }
34
+
35
+
36
+ BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
37
+ : data_(nullptr),
38
+ maxelements_(0),
39
+ cur_element_count(0),
40
+ size_per_element_(0),
41
+ data_size_(0),
42
+ dist_func_param_(nullptr) {
43
+ loadIndex(location, s);
44
+ }
45
+
46
+
47
+ BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
48
+ maxelements_ = maxElements;
49
+ data_size_ = s->get_data_size();
50
+ fstdistfunc_ = s->get_dist_func();
51
+ dist_func_param_ = s->get_dist_func_param();
52
+ size_per_element_ = data_size_ + sizeof(labeltype);
53
+ data_ = (char *) malloc(maxElements * size_per_element_);
54
+ if (data_ == nullptr)
55
+ throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
56
+ cur_element_count = 0;
57
+ }
58
+
59
+
60
+ ~BruteforceSearch() {
61
+ free(data_);
62
+ }
63
+
64
+
65
+ void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
66
+ int idx;
67
+ {
68
+ std::unique_lock<std::mutex> lock(index_lock);
69
+
70
+ auto search = dict_external_to_internal.find(label);
71
+ if (search != dict_external_to_internal.end()) {
72
+ idx = search->second;
73
+ } else {
74
+ if (cur_element_count >= maxelements_) {
75
+ throw std::runtime_error("The number of elements exceeds the specified limit\n");
66
76
  }
77
+ idx = cur_element_count;
78
+ dict_external_to_internal[label] = idx;
79
+ cur_element_count++;
67
80
  }
68
- memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
69
- memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
70
-
71
-
72
-
73
-
74
- };
75
-
76
- void removePoint(labeltype cur_external) {
77
- size_t cur_c=dict_external_to_internal[cur_external];
78
-
79
- dict_external_to_internal.erase(cur_external);
80
-
81
- labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
82
- dict_external_to_internal[label]=cur_c;
83
- memcpy(data_ + size_per_element_ * cur_c,
84
- data_ + size_per_element_ * (cur_element_count-1),
85
- data_size_+sizeof(labeltype));
86
- cur_element_count--;
87
-
88
81
  }
89
-
90
-
91
- std::priority_queue<std::pair<dist_t, labeltype >>
92
- searchKnn(const void *query_data, size_t k) const {
93
- std::priority_queue<std::pair<dist_t, labeltype >> topResults;
94
- if (cur_element_count == 0) return topResults;
95
- for (int i = 0; i < k; i++) {
96
- dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
97
- topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
98
- data_size_))));
82
+ memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
83
+ memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
84
+ }
85
+
86
+
87
+ void removePoint(labeltype cur_external) {
88
+ size_t cur_c = dict_external_to_internal[cur_external];
89
+
90
+ dict_external_to_internal.erase(cur_external);
91
+
92
+ labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
93
+ dict_external_to_internal[label] = cur_c;
94
+ memcpy(data_ + size_per_element_ * cur_c,
95
+ data_ + size_per_element_ * (cur_element_count-1),
96
+ data_size_+sizeof(labeltype));
97
+ cur_element_count--;
98
+ }
99
+
100
+
101
+ std::priority_queue<std::pair<dist_t, labeltype >>
102
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
103
+ assert(k <= cur_element_count);
104
+ std::priority_queue<std::pair<dist_t, labeltype >> topResults;
105
+ if (cur_element_count == 0) return topResults;
106
+ for (int i = 0; i < k; i++) {
107
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
108
+ labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
109
+ if ((!isIdAllowed) || (*isIdAllowed)(label)) {
110
+ topResults.push(std::pair<dist_t, labeltype>(dist, label));
99
111
  }
100
- dist_t lastdist = topResults.top().first;
101
- for (int i = k; i < cur_element_count; i++) {
102
- dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
103
- if (dist <= lastdist) {
104
- topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
105
- data_size_))));
106
- if (topResults.size() > k)
107
- topResults.pop();
108
- lastdist = topResults.top().first;
112
+ }
113
+ dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
114
+ for (int i = k; i < cur_element_count; i++) {
115
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
116
+ if (dist <= lastdist) {
117
+ labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
118
+ if ((!isIdAllowed) || (*isIdAllowed)(label)) {
119
+ topResults.push(std::pair<dist_t, labeltype>(dist, label));
109
120
  }
121
+ if (topResults.size() > k)
122
+ topResults.pop();
110
123
 
124
+ if (!topResults.empty()) {
125
+ lastdist = topResults.top().first;
126
+ }
111
127
  }
112
- return topResults;
113
- };
114
-
115
- void saveIndex(const std::string &location) {
116
- std::ofstream output(location, std::ios::binary);
117
- std::streampos position;
118
-
119
- writeBinaryPOD(output, maxelements_);
120
- writeBinaryPOD(output, size_per_element_);
121
- writeBinaryPOD(output, cur_element_count);
122
-
123
- output.write(data_, maxelements_ * size_per_element_);
124
-
125
- output.close();
126
128
  }
129
+ return topResults;
130
+ }
127
131
 
128
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
129
132
 
133
+ void saveIndex(const std::string &location) {
134
+ std::ofstream output(location, std::ios::binary);
135
+ std::streampos position;
130
136
 
131
- std::ifstream input(location, std::ios::binary);
137
+ writeBinaryPOD(output, maxelements_);
138
+ writeBinaryPOD(output, size_per_element_);
139
+ writeBinaryPOD(output, cur_element_count);
132
140
 
133
- if (!input.is_open())
134
- throw std::runtime_error("Cannot open file");
141
+ output.write(data_, maxelements_ * size_per_element_);
135
142
 
136
- std::streampos position;
143
+ output.close();
144
+ }
137
145
 
138
- readBinaryPOD(input, maxelements_);
139
- readBinaryPOD(input, size_per_element_);
140
- readBinaryPOD(input, cur_element_count);
141
146
 
142
- data_size_ = s->get_data_size();
143
- fstdistfunc_ = s->get_dist_func();
144
- dist_func_param_ = s->get_dist_func_param();
145
- size_per_element_ = data_size_ + sizeof(labeltype);
146
- data_ = (char *) malloc(maxelements_ * size_per_element_);
147
- if (data_ == nullptr)
148
- std::runtime_error("Not enough memory: loadIndex failed to allocate data");
147
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
148
+ std::ifstream input(location, std::ios::binary);
149
+ std::streampos position;
149
150
 
150
- input.read(data_, maxelements_ * size_per_element_);
151
+ readBinaryPOD(input, maxelements_);
152
+ readBinaryPOD(input, size_per_element_);
153
+ readBinaryPOD(input, cur_element_count);
151
154
 
152
- input.close();
155
+ data_size_ = s->get_data_size();
156
+ fstdistfunc_ = s->get_dist_func();
157
+ dist_func_param_ = s->get_dist_func_param();
158
+ size_per_element_ = data_size_ + sizeof(labeltype);
159
+ data_ = (char *) malloc(maxelements_ * size_per_element_);
160
+ if (data_ == nullptr)
161
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
153
162
 
154
- }
163
+ input.read(data_, maxelements_ * size_per_element_);
155
164
 
156
- };
157
- }
165
+ input.close();
166
+ }
167
+ };
168
+ } // namespace hnswlib