umappp 0.1.5 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +11 -4
- data/ext/umappp/umappp.cpp +41 -43
- data/lib/umappp/version.rb +1 -1
- data/lib/umappp.rb +5 -4
- data/vendor/aarand/aarand.hpp +141 -28
- data/vendor/annoy/annoylib.h +1 -1
- data/vendor/hnswlib/bruteforce.h +142 -127
- data/vendor/hnswlib/hnswalg.h +1018 -939
- data/vendor/hnswlib/hnswlib.h +149 -58
- data/vendor/hnswlib/space_ip.h +322 -229
- data/vendor/hnswlib/space_l2.h +283 -240
- data/vendor/hnswlib/visited_list_pool.h +54 -55
- data/vendor/irlba/irlba.hpp +12 -27
- data/vendor/irlba/lanczos.hpp +30 -31
- data/vendor/irlba/parallel.hpp +37 -38
- data/vendor/irlba/utils.hpp +12 -23
- data/vendor/irlba/wrappers.hpp +239 -70
- data/vendor/kmeans/Details.hpp +1 -1
- data/vendor/kmeans/HartiganWong.hpp +28 -2
- data/vendor/kmeans/InitializeKmeansPP.hpp +29 -1
- data/vendor/kmeans/Kmeans.hpp +25 -2
- data/vendor/kmeans/Lloyd.hpp +29 -2
- data/vendor/kmeans/MiniBatch.hpp +48 -8
- data/vendor/knncolle/Annoy/Annoy.hpp +3 -0
- data/vendor/knncolle/Hnsw/Hnsw.hpp +3 -0
- data/vendor/knncolle/Kmknn/Kmknn.hpp +11 -1
- data/vendor/knncolle/utils/find_nearest_neighbors.hpp +8 -6
- data/vendor/umappp/Umap.hpp +85 -43
- data/vendor/umappp/optimize_layout.hpp +410 -133
- data/vendor/umappp/spectral_init.hpp +4 -1
- metadata +6 -6
data/vendor/hnswlib/bruteforce.h
CHANGED
@@ -3,150 +3,165 @@
|
|
3
3
|
#include <fstream>
|
4
4
|
#include <mutex>
|
5
5
|
#include <algorithm>
|
6
|
+
#include <assert.h>
|
6
7
|
|
7
8
|
namespace hnswlib {
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
cur_element_count
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
9
|
+
template<typename dist_t>
|
10
|
+
class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
11
|
+
public:
|
12
|
+
char *data_;
|
13
|
+
size_t maxelements_;
|
14
|
+
size_t cur_element_count;
|
15
|
+
size_t size_per_element_;
|
16
|
+
|
17
|
+
size_t data_size_;
|
18
|
+
DISTFUNC <dist_t> fstdistfunc_;
|
19
|
+
void *dist_func_param_;
|
20
|
+
std::mutex index_lock;
|
21
|
+
|
22
|
+
std::unordered_map<labeltype, size_t > dict_external_to_internal;
|
23
|
+
|
24
|
+
|
25
|
+
BruteforceSearch(SpaceInterface <dist_t> *s)
|
26
|
+
: data_(nullptr),
|
27
|
+
maxelements_(0),
|
28
|
+
cur_element_count(0),
|
29
|
+
size_per_element_(0),
|
30
|
+
data_size_(0),
|
31
|
+
dist_func_param_(nullptr) {
|
32
|
+
}
|
33
|
+
|
34
|
+
|
35
|
+
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
|
36
|
+
: data_(nullptr),
|
37
|
+
maxelements_(0),
|
38
|
+
cur_element_count(0),
|
39
|
+
size_per_element_(0),
|
40
|
+
data_size_(0),
|
41
|
+
dist_func_param_(nullptr) {
|
42
|
+
loadIndex(location, s);
|
43
|
+
}
|
44
|
+
|
45
|
+
|
46
|
+
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
|
47
|
+
maxelements_ = maxElements;
|
48
|
+
data_size_ = s->get_data_size();
|
49
|
+
fstdistfunc_ = s->get_dist_func();
|
50
|
+
dist_func_param_ = s->get_dist_func_param();
|
51
|
+
size_per_element_ = data_size_ + sizeof(labeltype);
|
52
|
+
data_ = (char *) malloc(maxElements * size_per_element_);
|
53
|
+
if (data_ == nullptr)
|
54
|
+
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
|
55
|
+
cur_element_count = 0;
|
56
|
+
}
|
57
|
+
|
58
|
+
|
59
|
+
~BruteforceSearch() {
|
60
|
+
free(data_);
|
61
|
+
}
|
62
|
+
|
63
|
+
|
64
|
+
void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
|
65
|
+
int idx;
|
66
|
+
{
|
67
|
+
std::unique_lock<std::mutex> lock(index_lock);
|
68
|
+
|
69
|
+
auto search = dict_external_to_internal.find(label);
|
70
|
+
if (search != dict_external_to_internal.end()) {
|
71
|
+
idx = search->second;
|
72
|
+
} else {
|
73
|
+
if (cur_element_count >= maxelements_) {
|
74
|
+
throw std::runtime_error("The number of elements exceeds the specified limit\n");
|
65
75
|
}
|
76
|
+
idx = cur_element_count;
|
77
|
+
dict_external_to_internal[label] = idx;
|
78
|
+
cur_element_count++;
|
66
79
|
}
|
67
|
-
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
|
68
|
-
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
};
|
74
|
-
|
75
|
-
void removePoint(labeltype cur_external) {
|
76
|
-
size_t cur_c=dict_external_to_internal[cur_external];
|
77
|
-
|
78
|
-
dict_external_to_internal.erase(cur_external);
|
79
|
-
|
80
|
-
labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
81
|
-
dict_external_to_internal[label]=cur_c;
|
82
|
-
memcpy(data_ + size_per_element_ * cur_c,
|
83
|
-
data_ + size_per_element_ * (cur_element_count-1),
|
84
|
-
data_size_+sizeof(labeltype));
|
85
|
-
cur_element_count--;
|
86
|
-
|
87
80
|
}
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
81
|
+
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
|
82
|
+
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
|
83
|
+
}
|
84
|
+
|
85
|
+
|
86
|
+
void removePoint(labeltype cur_external) {
|
87
|
+
size_t cur_c = dict_external_to_internal[cur_external];
|
88
|
+
|
89
|
+
dict_external_to_internal.erase(cur_external);
|
90
|
+
|
91
|
+
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
92
|
+
dict_external_to_internal[label] = cur_c;
|
93
|
+
memcpy(data_ + size_per_element_ * cur_c,
|
94
|
+
data_ + size_per_element_ * (cur_element_count-1),
|
95
|
+
data_size_+sizeof(labeltype));
|
96
|
+
cur_element_count--;
|
97
|
+
}
|
98
|
+
|
99
|
+
|
100
|
+
std::priority_queue<std::pair<dist_t, labeltype >>
|
101
|
+
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
|
102
|
+
assert(k <= cur_element_count);
|
103
|
+
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
|
104
|
+
if (cur_element_count == 0) return topResults;
|
105
|
+
for (int i = 0; i < k; i++) {
|
106
|
+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
107
|
+
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
|
108
|
+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
109
|
+
topResults.push(std::pair<dist_t, labeltype>(dist, label));
|
98
110
|
}
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
lastdist = topResults.top().first;
|
111
|
+
}
|
112
|
+
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
|
113
|
+
for (int i = k; i < cur_element_count; i++) {
|
114
|
+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
115
|
+
if (dist <= lastdist) {
|
116
|
+
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
|
117
|
+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
118
|
+
topResults.push(std::pair<dist_t, labeltype>(dist, label));
|
108
119
|
}
|
120
|
+
if (topResults.size() > k)
|
121
|
+
topResults.pop();
|
109
122
|
|
123
|
+
if (!topResults.empty()) {
|
124
|
+
lastdist = topResults.top().first;
|
125
|
+
}
|
110
126
|
}
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
void saveIndex(const std::string &location) {
|
115
|
-
std::ofstream output(location, std::ios::binary);
|
116
|
-
std::streampos position;
|
117
|
-
|
118
|
-
writeBinaryPOD(output, maxelements_);
|
119
|
-
writeBinaryPOD(output, size_per_element_);
|
120
|
-
writeBinaryPOD(output, cur_element_count);
|
127
|
+
}
|
128
|
+
return topResults;
|
129
|
+
}
|
121
130
|
|
122
|
-
output.write(data_, maxelements_ * size_per_element_);
|
123
131
|
|
124
|
-
|
125
|
-
|
132
|
+
void saveIndex(const std::string &location) {
|
133
|
+
std::ofstream output(location, std::ios::binary);
|
134
|
+
std::streampos position;
|
126
135
|
|
127
|
-
|
136
|
+
writeBinaryPOD(output, maxelements_);
|
137
|
+
writeBinaryPOD(output, size_per_element_);
|
138
|
+
writeBinaryPOD(output, cur_element_count);
|
128
139
|
|
140
|
+
output.write(data_, maxelements_ * size_per_element_);
|
129
141
|
|
130
|
-
|
131
|
-
|
142
|
+
output.close();
|
143
|
+
}
|
132
144
|
|
133
|
-
readBinaryPOD(input, maxelements_);
|
134
|
-
readBinaryPOD(input, size_per_element_);
|
135
|
-
readBinaryPOD(input, cur_element_count);
|
136
145
|
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
size_per_element_ = data_size_ + sizeof(labeltype);
|
141
|
-
data_ = (char *) malloc(maxelements_ * size_per_element_);
|
142
|
-
if (data_ == nullptr)
|
143
|
-
std::runtime_error("Not enough memory: loadIndex failed to allocate data");
|
146
|
+
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
|
147
|
+
std::ifstream input(location, std::ios::binary);
|
148
|
+
std::streampos position;
|
144
149
|
|
145
|
-
|
150
|
+
readBinaryPOD(input, maxelements_);
|
151
|
+
readBinaryPOD(input, size_per_element_);
|
152
|
+
readBinaryPOD(input, cur_element_count);
|
146
153
|
|
147
|
-
|
154
|
+
data_size_ = s->get_data_size();
|
155
|
+
fstdistfunc_ = s->get_dist_func();
|
156
|
+
dist_func_param_ = s->get_dist_func_param();
|
157
|
+
size_per_element_ = data_size_ + sizeof(labeltype);
|
158
|
+
data_ = (char *) malloc(maxelements_ * size_per_element_);
|
159
|
+
if (data_ == nullptr)
|
160
|
+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
|
148
161
|
|
149
|
-
|
162
|
+
input.read(data_, maxelements_ * size_per_element_);
|
150
163
|
|
151
|
-
|
152
|
-
}
|
164
|
+
input.close();
|
165
|
+
}
|
166
|
+
};
|
167
|
+
} // namespace hnswlib
|