hnswlib 0.6.1 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/ext/hnswlib/hnswlibext.cpp +2 -3
- data/ext/hnswlib/hnswlibext.hpp +202 -62
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +17 -10
- data/sig/hnswlib.rbs +6 -6
- metadata +4 -3
@@ -3,155 +3,166 @@
|
|
3
3
|
#include <fstream>
|
4
4
|
#include <mutex>
|
5
5
|
#include <algorithm>
|
6
|
+
#include <assert.h>
|
6
7
|
|
7
8
|
namespace hnswlib {
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
cur_element_count
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
9
|
+
template<typename dist_t>
|
10
|
+
class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
11
|
+
public:
|
12
|
+
char *data_;
|
13
|
+
size_t maxelements_;
|
14
|
+
size_t cur_element_count;
|
15
|
+
size_t size_per_element_;
|
16
|
+
|
17
|
+
size_t data_size_;
|
18
|
+
DISTFUNC <dist_t> fstdistfunc_;
|
19
|
+
void *dist_func_param_;
|
20
|
+
std::mutex index_lock;
|
21
|
+
|
22
|
+
std::unordered_map<labeltype, size_t > dict_external_to_internal;
|
23
|
+
|
24
|
+
BruteforceSearch() : data_(nullptr) { }
|
25
|
+
|
26
|
+
BruteforceSearch(SpaceInterface <dist_t> *s)
|
27
|
+
: data_(nullptr),
|
28
|
+
maxelements_(0),
|
29
|
+
cur_element_count(0),
|
30
|
+
size_per_element_(0),
|
31
|
+
data_size_(0),
|
32
|
+
dist_func_param_(nullptr) {
|
33
|
+
}
|
34
|
+
|
35
|
+
|
36
|
+
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
|
37
|
+
: data_(nullptr),
|
38
|
+
maxelements_(0),
|
39
|
+
cur_element_count(0),
|
40
|
+
size_per_element_(0),
|
41
|
+
data_size_(0),
|
42
|
+
dist_func_param_(nullptr) {
|
43
|
+
loadIndex(location, s);
|
44
|
+
}
|
45
|
+
|
46
|
+
|
47
|
+
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
|
48
|
+
maxelements_ = maxElements;
|
49
|
+
data_size_ = s->get_data_size();
|
50
|
+
fstdistfunc_ = s->get_dist_func();
|
51
|
+
dist_func_param_ = s->get_dist_func_param();
|
52
|
+
size_per_element_ = data_size_ + sizeof(labeltype);
|
53
|
+
data_ = (char *) malloc(maxElements * size_per_element_);
|
54
|
+
if (data_ == nullptr)
|
55
|
+
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
|
56
|
+
cur_element_count = 0;
|
57
|
+
}
|
58
|
+
|
59
|
+
|
60
|
+
~BruteforceSearch() {
|
61
|
+
free(data_);
|
62
|
+
}
|
63
|
+
|
64
|
+
|
65
|
+
void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
|
66
|
+
int idx;
|
67
|
+
{
|
68
|
+
std::unique_lock<std::mutex> lock(index_lock);
|
69
|
+
|
70
|
+
auto search = dict_external_to_internal.find(label);
|
71
|
+
if (search != dict_external_to_internal.end()) {
|
72
|
+
idx = search->second;
|
73
|
+
} else {
|
74
|
+
if (cur_element_count >= maxelements_) {
|
75
|
+
throw std::runtime_error("The number of elements exceeds the specified limit\n");
|
66
76
|
}
|
77
|
+
idx = cur_element_count;
|
78
|
+
dict_external_to_internal[label] = idx;
|
79
|
+
cur_element_count++;
|
67
80
|
}
|
68
|
-
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
|
69
|
-
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
};
|
75
|
-
|
76
|
-
void removePoint(labeltype cur_external) {
|
77
|
-
size_t cur_c=dict_external_to_internal[cur_external];
|
78
|
-
|
79
|
-
dict_external_to_internal.erase(cur_external);
|
80
|
-
|
81
|
-
labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
82
|
-
dict_external_to_internal[label]=cur_c;
|
83
|
-
memcpy(data_ + size_per_element_ * cur_c,
|
84
|
-
data_ + size_per_element_ * (cur_element_count-1),
|
85
|
-
data_size_+sizeof(labeltype));
|
86
|
-
cur_element_count--;
|
87
|
-
|
88
81
|
}
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
82
|
+
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
|
83
|
+
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
|
84
|
+
}
|
85
|
+
|
86
|
+
|
87
|
+
void removePoint(labeltype cur_external) {
|
88
|
+
size_t cur_c = dict_external_to_internal[cur_external];
|
89
|
+
|
90
|
+
dict_external_to_internal.erase(cur_external);
|
91
|
+
|
92
|
+
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
93
|
+
dict_external_to_internal[label] = cur_c;
|
94
|
+
memcpy(data_ + size_per_element_ * cur_c,
|
95
|
+
data_ + size_per_element_ * (cur_element_count-1),
|
96
|
+
data_size_+sizeof(labeltype));
|
97
|
+
cur_element_count--;
|
98
|
+
}
|
99
|
+
|
100
|
+
|
101
|
+
std::priority_queue<std::pair<dist_t, labeltype >>
|
102
|
+
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
|
103
|
+
assert(k <= cur_element_count);
|
104
|
+
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
|
105
|
+
if (cur_element_count == 0) return topResults;
|
106
|
+
for (int i = 0; i < k; i++) {
|
107
|
+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
108
|
+
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
|
109
|
+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
110
|
+
topResults.push(std::pair<dist_t, labeltype>(dist, label));
|
99
111
|
}
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
lastdist = topResults.top().first;
|
112
|
+
}
|
113
|
+
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
|
114
|
+
for (int i = k; i < cur_element_count; i++) {
|
115
|
+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
116
|
+
if (dist <= lastdist) {
|
117
|
+
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
|
118
|
+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
119
|
+
topResults.push(std::pair<dist_t, labeltype>(dist, label));
|
109
120
|
}
|
121
|
+
if (topResults.size() > k)
|
122
|
+
topResults.pop();
|
110
123
|
|
124
|
+
if (!topResults.empty()) {
|
125
|
+
lastdist = topResults.top().first;
|
126
|
+
}
|
111
127
|
}
|
112
|
-
return topResults;
|
113
|
-
};
|
114
|
-
|
115
|
-
void saveIndex(const std::string &location) {
|
116
|
-
std::ofstream output(location, std::ios::binary);
|
117
|
-
std::streampos position;
|
118
|
-
|
119
|
-
writeBinaryPOD(output, maxelements_);
|
120
|
-
writeBinaryPOD(output, size_per_element_);
|
121
|
-
writeBinaryPOD(output, cur_element_count);
|
122
|
-
|
123
|
-
output.write(data_, maxelements_ * size_per_element_);
|
124
|
-
|
125
|
-
output.close();
|
126
128
|
}
|
129
|
+
return topResults;
|
130
|
+
}
|
127
131
|
|
128
|
-
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
|
129
132
|
|
133
|
+
void saveIndex(const std::string &location) {
|
134
|
+
std::ofstream output(location, std::ios::binary);
|
135
|
+
std::streampos position;
|
130
136
|
|
131
|
-
|
137
|
+
writeBinaryPOD(output, maxelements_);
|
138
|
+
writeBinaryPOD(output, size_per_element_);
|
139
|
+
writeBinaryPOD(output, cur_element_count);
|
132
140
|
|
133
|
-
|
134
|
-
throw std::runtime_error("Cannot open file");
|
141
|
+
output.write(data_, maxelements_ * size_per_element_);
|
135
142
|
|
136
|
-
|
143
|
+
output.close();
|
144
|
+
}
|
137
145
|
|
138
|
-
readBinaryPOD(input, maxelements_);
|
139
|
-
readBinaryPOD(input, size_per_element_);
|
140
|
-
readBinaryPOD(input, cur_element_count);
|
141
146
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
size_per_element_ = data_size_ + sizeof(labeltype);
|
146
|
-
data_ = (char *) malloc(maxelements_ * size_per_element_);
|
147
|
-
if (data_ == nullptr)
|
148
|
-
std::runtime_error("Not enough memory: loadIndex failed to allocate data");
|
147
|
+
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
|
148
|
+
std::ifstream input(location, std::ios::binary);
|
149
|
+
std::streampos position;
|
149
150
|
|
150
|
-
|
151
|
+
readBinaryPOD(input, maxelements_);
|
152
|
+
readBinaryPOD(input, size_per_element_);
|
153
|
+
readBinaryPOD(input, cur_element_count);
|
151
154
|
|
152
|
-
|
155
|
+
data_size_ = s->get_data_size();
|
156
|
+
fstdistfunc_ = s->get_dist_func();
|
157
|
+
dist_func_param_ = s->get_dist_func_param();
|
158
|
+
size_per_element_ = data_size_ + sizeof(labeltype);
|
159
|
+
data_ = (char *) malloc(maxelements_ * size_per_element_);
|
160
|
+
if (data_ == nullptr)
|
161
|
+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
|
153
162
|
|
154
|
-
|
163
|
+
input.read(data_, maxelements_ * size_per_element_);
|
155
164
|
|
156
|
-
|
157
|
-
}
|
165
|
+
input.close();
|
166
|
+
}
|
167
|
+
};
|
168
|
+
} // namespace hnswlib
|