umappp 0.1.5 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -3,150 +3,165 @@
3
3
  #include <fstream>
4
4
  #include <mutex>
5
5
  #include <algorithm>
6
+ #include <assert.h>
6
7
 
7
8
  namespace hnswlib {
8
- template<typename dist_t>
9
- class BruteforceSearch : public AlgorithmInterface<dist_t> {
10
- public:
11
- BruteforceSearch(SpaceInterface <dist_t> *s) {
12
-
13
- }
14
- BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location) {
15
- loadIndex(location, s);
16
- }
17
-
18
- BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
19
- maxelements_ = maxElements;
20
- data_size_ = s->get_data_size();
21
- fstdistfunc_ = s->get_dist_func();
22
- dist_func_param_ = s->get_dist_func_param();
23
- size_per_element_ = data_size_ + sizeof(labeltype);
24
- data_ = (char *) malloc(maxElements * size_per_element_);
25
- if (data_ == nullptr)
26
- std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
27
- cur_element_count = 0;
28
- }
29
-
30
- ~BruteforceSearch() {
31
- free(data_);
32
- }
33
-
34
- char *data_;
35
- size_t maxelements_;
36
- size_t cur_element_count;
37
- size_t size_per_element_;
38
-
39
- size_t data_size_;
40
- DISTFUNC <dist_t> fstdistfunc_;
41
- void *dist_func_param_;
42
- std::mutex index_lock;
43
-
44
- std::unordered_map<labeltype,size_t > dict_external_to_internal;
45
-
46
- void addPoint(const void *datapoint, labeltype label) {
47
-
48
- int idx;
49
- {
50
- std::unique_lock<std::mutex> lock(index_lock);
51
-
52
-
53
-
54
- auto search=dict_external_to_internal.find(label);
55
- if (search != dict_external_to_internal.end()) {
56
- idx=search->second;
57
- }
58
- else{
59
- if (cur_element_count >= maxelements_) {
60
- throw std::runtime_error("The number of elements exceeds the specified limit\n");
61
- }
62
- idx=cur_element_count;
63
- dict_external_to_internal[label] = idx;
64
- cur_element_count++;
9
+ template<typename dist_t>
10
+ class BruteforceSearch : public AlgorithmInterface<dist_t> {
11
+ public:
12
+ char *data_;
13
+ size_t maxelements_;
14
+ size_t cur_element_count;
15
+ size_t size_per_element_;
16
+
17
+ size_t data_size_;
18
+ DISTFUNC <dist_t> fstdistfunc_;
19
+ void *dist_func_param_;
20
+ std::mutex index_lock;
21
+
22
+ std::unordered_map<labeltype, size_t > dict_external_to_internal;
23
+
24
+
25
+ BruteforceSearch(SpaceInterface <dist_t> *s)
26
+ : data_(nullptr),
27
+ maxelements_(0),
28
+ cur_element_count(0),
29
+ size_per_element_(0),
30
+ data_size_(0),
31
+ dist_func_param_(nullptr) {
32
+ }
33
+
34
+
35
+ BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
36
+ : data_(nullptr),
37
+ maxelements_(0),
38
+ cur_element_count(0),
39
+ size_per_element_(0),
40
+ data_size_(0),
41
+ dist_func_param_(nullptr) {
42
+ loadIndex(location, s);
43
+ }
44
+
45
+
46
+ BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
47
+ maxelements_ = maxElements;
48
+ data_size_ = s->get_data_size();
49
+ fstdistfunc_ = s->get_dist_func();
50
+ dist_func_param_ = s->get_dist_func_param();
51
+ size_per_element_ = data_size_ + sizeof(labeltype);
52
+ data_ = (char *) malloc(maxElements * size_per_element_);
53
+ if (data_ == nullptr)
54
+ throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
55
+ cur_element_count = 0;
56
+ }
57
+
58
+
59
+ ~BruteforceSearch() {
60
+ free(data_);
61
+ }
62
+
63
+
64
+ void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
65
+ int idx;
66
+ {
67
+ std::unique_lock<std::mutex> lock(index_lock);
68
+
69
+ auto search = dict_external_to_internal.find(label);
70
+ if (search != dict_external_to_internal.end()) {
71
+ idx = search->second;
72
+ } else {
73
+ if (cur_element_count >= maxelements_) {
74
+ throw std::runtime_error("The number of elements exceeds the specified limit\n");
65
75
  }
76
+ idx = cur_element_count;
77
+ dict_external_to_internal[label] = idx;
78
+ cur_element_count++;
66
79
  }
67
- memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
68
- memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
69
-
70
-
71
-
72
-
73
- };
74
-
75
- void removePoint(labeltype cur_external) {
76
- size_t cur_c=dict_external_to_internal[cur_external];
77
-
78
- dict_external_to_internal.erase(cur_external);
79
-
80
- labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
81
- dict_external_to_internal[label]=cur_c;
82
- memcpy(data_ + size_per_element_ * cur_c,
83
- data_ + size_per_element_ * (cur_element_count-1),
84
- data_size_+sizeof(labeltype));
85
- cur_element_count--;
86
-
87
80
  }
88
-
89
-
90
- std::priority_queue<std::pair<dist_t, labeltype >>
91
- searchKnn(const void *query_data, size_t k) const {
92
- std::priority_queue<std::pair<dist_t, labeltype >> topResults;
93
- if (cur_element_count == 0) return topResults;
94
- for (int i = 0; i < k; i++) {
95
- dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
96
- topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
97
- data_size_))));
81
+ memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
82
+ memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
83
+ }
84
+
85
+
86
+ void removePoint(labeltype cur_external) {
87
+ size_t cur_c = dict_external_to_internal[cur_external];
88
+
89
+ dict_external_to_internal.erase(cur_external);
90
+
91
+ labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
92
+ dict_external_to_internal[label] = cur_c;
93
+ memcpy(data_ + size_per_element_ * cur_c,
94
+ data_ + size_per_element_ * (cur_element_count-1),
95
+ data_size_+sizeof(labeltype));
96
+ cur_element_count--;
97
+ }
98
+
99
+
100
+ std::priority_queue<std::pair<dist_t, labeltype >>
101
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
102
+ assert(k <= cur_element_count);
103
+ std::priority_queue<std::pair<dist_t, labeltype >> topResults;
104
+ if (cur_element_count == 0) return topResults;
105
+ for (int i = 0; i < k; i++) {
106
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
107
+ labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
108
+ if ((!isIdAllowed) || (*isIdAllowed)(label)) {
109
+ topResults.push(std::pair<dist_t, labeltype>(dist, label));
98
110
  }
99
- dist_t lastdist = topResults.top().first;
100
- for (int i = k; i < cur_element_count; i++) {
101
- dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
102
- if (dist <= lastdist) {
103
- topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
104
- data_size_))));
105
- if (topResults.size() > k)
106
- topResults.pop();
107
- lastdist = topResults.top().first;
111
+ }
112
+ dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
113
+ for (int i = k; i < cur_element_count; i++) {
114
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
115
+ if (dist <= lastdist) {
116
+ labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
117
+ if ((!isIdAllowed) || (*isIdAllowed)(label)) {
118
+ topResults.push(std::pair<dist_t, labeltype>(dist, label));
108
119
  }
120
+ if (topResults.size() > k)
121
+ topResults.pop();
109
122
 
123
+ if (!topResults.empty()) {
124
+ lastdist = topResults.top().first;
125
+ }
110
126
  }
111
- return topResults;
112
- };
113
-
114
- void saveIndex(const std::string &location) {
115
- std::ofstream output(location, std::ios::binary);
116
- std::streampos position;
117
-
118
- writeBinaryPOD(output, maxelements_);
119
- writeBinaryPOD(output, size_per_element_);
120
- writeBinaryPOD(output, cur_element_count);
127
+ }
128
+ return topResults;
129
+ }
121
130
 
122
- output.write(data_, maxelements_ * size_per_element_);
123
131
 
124
- output.close();
125
- }
132
+ void saveIndex(const std::string &location) {
133
+ std::ofstream output(location, std::ios::binary);
134
+ std::streampos position;
126
135
 
127
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
136
+ writeBinaryPOD(output, maxelements_);
137
+ writeBinaryPOD(output, size_per_element_);
138
+ writeBinaryPOD(output, cur_element_count);
128
139
 
140
+ output.write(data_, maxelements_ * size_per_element_);
129
141
 
130
- std::ifstream input(location, std::ios::binary);
131
- std::streampos position;
142
+ output.close();
143
+ }
132
144
 
133
- readBinaryPOD(input, maxelements_);
134
- readBinaryPOD(input, size_per_element_);
135
- readBinaryPOD(input, cur_element_count);
136
145
 
137
- data_size_ = s->get_data_size();
138
- fstdistfunc_ = s->get_dist_func();
139
- dist_func_param_ = s->get_dist_func_param();
140
- size_per_element_ = data_size_ + sizeof(labeltype);
141
- data_ = (char *) malloc(maxelements_ * size_per_element_);
142
- if (data_ == nullptr)
143
- std::runtime_error("Not enough memory: loadIndex failed to allocate data");
146
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
147
+ std::ifstream input(location, std::ios::binary);
148
+ std::streampos position;
144
149
 
145
- input.read(data_, maxelements_ * size_per_element_);
150
+ readBinaryPOD(input, maxelements_);
151
+ readBinaryPOD(input, size_per_element_);
152
+ readBinaryPOD(input, cur_element_count);
146
153
 
147
- input.close();
154
+ data_size_ = s->get_data_size();
155
+ fstdistfunc_ = s->get_dist_func();
156
+ dist_func_param_ = s->get_dist_func_param();
157
+ size_per_element_ = data_size_ + sizeof(labeltype);
158
+ data_ = (char *) malloc(maxelements_ * size_per_element_);
159
+ if (data_ == nullptr)
160
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
148
161
 
149
- }
162
+ input.read(data_, maxelements_ * size_per_element_);
150
163
 
151
- };
152
- }
164
+ input.close();
165
+ }
166
+ };
167
+ } // namespace hnswlib