logosdb 0.7.8 → 0.7.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,324 @@
1
+ #pragma once
2
+ #include "hnswlib.h"
3
+
4
+ namespace hnswlib {
5
+
6
/// Scalar squared Euclidean (L2) distance between two float vectors.
///
/// @param pVect1v  Pointer to the first vector; read as `float` elements.
/// @param pVect2v  Pointer to the second vector; read as `float` elements.
/// @param qty_ptr  Pointer to a `size_t` holding the number of elements to compare.
/// @return Sum of (a[i] - b[i])^2 over the first `*qty_ptr` elements; 0 when the count is 0.
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    // The inputs are only read, so keep the pointees const instead of casting it away.
    const float *pVect1 = static_cast<const float *>(pVect1v);
    const float *pVect2 = static_cast<const float *>(pVect2v);
    const size_t qty = *static_cast<const size_t *>(qty_ptr);

    float res = 0;
    for (size_t i = 0; i < qty; i++) {
        const float t = pVect1[i] - pVect2[i];
        res += t * t;
    }
    return res;
}
21
+
22
+ #if defined(USE_AVX512)
23
+
24
+ // Favor using AVX512 if available.
25
+ static float
26
+ L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
27
+ float *pVect1 = (float *) pVect1v;
28
+ float *pVect2 = (float *) pVect2v;
29
+ size_t qty = *((size_t *) qty_ptr);
30
+ float PORTABLE_ALIGN64 TmpRes[16];
31
+ size_t qty16 = qty >> 4;
32
+
33
+ const float *pEnd1 = pVect1 + (qty16 << 4);
34
+
35
+ __m512 diff, v1, v2;
36
+ __m512 sum = _mm512_set1_ps(0);
37
+
38
+ while (pVect1 < pEnd1) {
39
+ v1 = _mm512_loadu_ps(pVect1);
40
+ pVect1 += 16;
41
+ v2 = _mm512_loadu_ps(pVect2);
42
+ pVect2 += 16;
43
+ diff = _mm512_sub_ps(v1, v2);
44
+ // sum = _mm512_fmadd_ps(diff, diff, sum);
45
+ sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
46
+ }
47
+
48
+ _mm512_store_ps(TmpRes, sum);
49
+ float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
50
+ TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
51
+ TmpRes[13] + TmpRes[14] + TmpRes[15];
52
+
53
+ return (res);
54
+ }
55
+ #endif
56
+
57
+ #if defined(USE_AVX)
58
+
59
+ // Favor using AVX if available.
60
+ static float
61
+ L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
62
+ float *pVect1 = (float *) pVect1v;
63
+ float *pVect2 = (float *) pVect2v;
64
+ size_t qty = *((size_t *) qty_ptr);
65
+ float PORTABLE_ALIGN32 TmpRes[8];
66
+ size_t qty16 = qty >> 4;
67
+
68
+ const float *pEnd1 = pVect1 + (qty16 << 4);
69
+
70
+ __m256 diff, v1, v2;
71
+ __m256 sum = _mm256_set1_ps(0);
72
+
73
+ while (pVect1 < pEnd1) {
74
+ v1 = _mm256_loadu_ps(pVect1);
75
+ pVect1 += 8;
76
+ v2 = _mm256_loadu_ps(pVect2);
77
+ pVect2 += 8;
78
+ diff = _mm256_sub_ps(v1, v2);
79
+ sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
80
+
81
+ v1 = _mm256_loadu_ps(pVect1);
82
+ pVect1 += 8;
83
+ v2 = _mm256_loadu_ps(pVect2);
84
+ pVect2 += 8;
85
+ diff = _mm256_sub_ps(v1, v2);
86
+ sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
87
+ }
88
+
89
+ _mm256_store_ps(TmpRes, sum);
90
+ return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
91
+ }
92
+
93
+ #endif
94
+
95
+ #if defined(USE_SSE)
96
+
97
+ static float
98
+ L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
99
+ float *pVect1 = (float *) pVect1v;
100
+ float *pVect2 = (float *) pVect2v;
101
+ size_t qty = *((size_t *) qty_ptr);
102
+ float PORTABLE_ALIGN32 TmpRes[8];
103
+ size_t qty16 = qty >> 4;
104
+
105
+ const float *pEnd1 = pVect1 + (qty16 << 4);
106
+
107
+ __m128 diff, v1, v2;
108
+ __m128 sum = _mm_set1_ps(0);
109
+
110
+ while (pVect1 < pEnd1) {
111
+ //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
112
+ v1 = _mm_loadu_ps(pVect1);
113
+ pVect1 += 4;
114
+ v2 = _mm_loadu_ps(pVect2);
115
+ pVect2 += 4;
116
+ diff = _mm_sub_ps(v1, v2);
117
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
118
+
119
+ v1 = _mm_loadu_ps(pVect1);
120
+ pVect1 += 4;
121
+ v2 = _mm_loadu_ps(pVect2);
122
+ pVect2 += 4;
123
+ diff = _mm_sub_ps(v1, v2);
124
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
125
+
126
+ v1 = _mm_loadu_ps(pVect1);
127
+ pVect1 += 4;
128
+ v2 = _mm_loadu_ps(pVect2);
129
+ pVect2 += 4;
130
+ diff = _mm_sub_ps(v1, v2);
131
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
132
+
133
+ v1 = _mm_loadu_ps(pVect1);
134
+ pVect1 += 4;
135
+ v2 = _mm_loadu_ps(pVect2);
136
+ pVect2 += 4;
137
+ diff = _mm_sub_ps(v1, v2);
138
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
139
+ }
140
+
141
+ _mm_store_ps(TmpRes, sum);
142
+ return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
143
+ }
144
+ #endif
145
+
146
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
147
+ static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
148
+
149
+ static float
150
+ L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
151
+ size_t qty = *((size_t *) qty_ptr);
152
+ size_t qty16 = qty >> 4 << 4;
153
+ float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
154
+ float *pVect1 = (float *) pVect1v + qty16;
155
+ float *pVect2 = (float *) pVect2v + qty16;
156
+
157
+ size_t qty_left = qty - qty16;
158
+ float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
159
+ return (res + res_tail);
160
+ }
161
+ #endif
162
+
163
+
164
+ #if defined(USE_SSE)
165
+ static float
166
+ L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
167
+ float PORTABLE_ALIGN32 TmpRes[8];
168
+ float *pVect1 = (float *) pVect1v;
169
+ float *pVect2 = (float *) pVect2v;
170
+ size_t qty = *((size_t *) qty_ptr);
171
+
172
+
173
+ size_t qty4 = qty >> 2;
174
+
175
+ const float *pEnd1 = pVect1 + (qty4 << 2);
176
+
177
+ __m128 diff, v1, v2;
178
+ __m128 sum = _mm_set1_ps(0);
179
+
180
+ while (pVect1 < pEnd1) {
181
+ v1 = _mm_loadu_ps(pVect1);
182
+ pVect1 += 4;
183
+ v2 = _mm_loadu_ps(pVect2);
184
+ pVect2 += 4;
185
+ diff = _mm_sub_ps(v1, v2);
186
+ sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
187
+ }
188
+ _mm_store_ps(TmpRes, sum);
189
+ return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
190
+ }
191
+
192
+ static float
193
+ L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
194
+ size_t qty = *((size_t *) qty_ptr);
195
+ size_t qty4 = qty >> 2 << 2;
196
+
197
+ float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
198
+ size_t qty_left = qty - qty4;
199
+
200
+ float *pVect1 = (float *) pVect1v + qty4;
201
+ float *pVect2 = (float *) pVect2v + qty4;
202
+ float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
203
+
204
+ return (res + res_tail);
205
+ }
206
+ #endif
207
+
208
+ class L2Space : public SpaceInterface<float> {
209
+ DISTFUNC<float> fstdistfunc_;
210
+ size_t data_size_;
211
+ size_t dim_;
212
+
213
+ public:
214
+ L2Space(size_t dim) {
215
+ fstdistfunc_ = L2Sqr;
216
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
217
+ #if defined(USE_AVX512)
218
+ if (AVX512Capable())
219
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
220
+ else if (AVXCapable())
221
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
222
+ #elif defined(USE_AVX)
223
+ if (AVXCapable())
224
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
225
+ #endif
226
+
227
+ if (dim % 16 == 0)
228
+ fstdistfunc_ = L2SqrSIMD16Ext;
229
+ else if (dim % 4 == 0)
230
+ fstdistfunc_ = L2SqrSIMD4Ext;
231
+ else if (dim > 16)
232
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
233
+ else if (dim > 4)
234
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
235
+ #endif
236
+ dim_ = dim;
237
+ data_size_ = dim * sizeof(float);
238
+ }
239
+
240
+ size_t get_data_size() {
241
+ return data_size_;
242
+ }
243
+
244
+ DISTFUNC<float> get_dist_func() {
245
+ return fstdistfunc_;
246
+ }
247
+
248
+ void *get_dist_func_param() {
249
+ return &dim_;
250
+ }
251
+
252
+ ~L2Space() {}
253
+ };
254
+
255
/// Squared L2 distance between two `unsigned char` vectors, processed four
/// elements at a time.
///
/// The element count is divided by 4 (rounding down), so when it is not a
/// multiple of 4 the trailing 1-3 elements are NOT included in the result.
///
/// @param pVect1   First vector; read as `unsigned char` elements.
/// @param pVect2   Second vector; read as `unsigned char` elements.
/// @param qty_ptr  Pointer to a `size_t` element count.
/// @return Sum of squared element differences over the processed elements.
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
    // The inputs are only read, so keep the pointees const instead of casting it away.
    const unsigned char *a = static_cast<const unsigned char *>(pVect1);
    const unsigned char *b = static_cast<const unsigned char *>(pVect2);
    const size_t groups = *static_cast<const size_t *>(qty_ptr) >> 2;

    int res = 0;
    for (size_t g = 0; g < groups; g++, a += 4, b += 4) {
        // Operands promote to int, so the differences may be negative.
        const int d0 = a[0] - b[0];
        const int d1 = a[1] - b[1];
        const int d2 = a[2] - b[2];
        const int d3 = a[3] - b[3];
        res += d0 * d0;
        res += d1 * d1;
        res += d2 * d2;
        res += d3 * d3;
    }
    return res;
}
279
+
280
/// Squared L2 distance between two `unsigned char` vectors, one element at a time.
///
/// @param pVect1   First vector; read as `unsigned char` elements.
/// @param pVect2   Second vector; read as `unsigned char` elements.
/// @param qty_ptr  Pointer to a `size_t` element count.
/// @return Sum of squared element differences; 0 when the count is 0.
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
    // The inputs are only read, so keep the pointees const instead of casting it away.
    const unsigned char* a = static_cast<const unsigned char*>(pVect1);
    const unsigned char* b = static_cast<const unsigned char*>(pVect2);
    const size_t qty = *static_cast<const size_t*>(qty_ptr);

    int res = 0;
    for (size_t i = 0; i < qty; i++) {
        // Operands promote to int, so the difference may be negative.
        const int d = a[i] - b[i];
        res += d * d;
    }
    return res;
}
293
+
294
+ class L2SpaceI : public SpaceInterface<int> {
295
+ DISTFUNC<int> fstdistfunc_;
296
+ size_t data_size_;
297
+ size_t dim_;
298
+
299
+ public:
300
+ L2SpaceI(size_t dim) {
301
+ if (dim % 4 == 0) {
302
+ fstdistfunc_ = L2SqrI4x;
303
+ } else {
304
+ fstdistfunc_ = L2SqrI;
305
+ }
306
+ dim_ = dim;
307
+ data_size_ = dim * sizeof(unsigned char);
308
+ }
309
+
310
+ size_t get_data_size() {
311
+ return data_size_;
312
+ }
313
+
314
+ DISTFUNC<int> get_dist_func() {
315
+ return fstdistfunc_;
316
+ }
317
+
318
+ void *get_dist_func_param() {
319
+ return &dim_;
320
+ }
321
+
322
+ ~L2SpaceI() {}
323
+ };
324
+ } // namespace hnswlib
@@ -0,0 +1,276 @@
1
+ #pragma once
2
+ #include "space_l2.h"
3
+ #include "space_ip.h"
4
+ #include <assert.h>
5
+ #include <unordered_map>
6
+
7
+ namespace hnswlib {
8
+
9
+ template<typename DOCIDTYPE>
10
+ class BaseMultiVectorSpace : public SpaceInterface<float> {
11
+ public:
12
+ virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;
13
+
14
+ virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
15
+ };
16
+
17
+
18
+ template<typename DOCIDTYPE>
19
+ class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
20
+ DISTFUNC<float> fstdistfunc_;
21
+ size_t data_size_;
22
+ size_t vector_size_;
23
+ size_t dim_;
24
+
25
+ public:
26
+ MultiVectorL2Space(size_t dim) {
27
+ fstdistfunc_ = L2Sqr;
28
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
29
+ #if defined(USE_AVX512)
30
+ if (AVX512Capable())
31
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
32
+ else if (AVXCapable())
33
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
34
+ #elif defined(USE_AVX)
35
+ if (AVXCapable())
36
+ L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
37
+ #endif
38
+
39
+ if (dim % 16 == 0)
40
+ fstdistfunc_ = L2SqrSIMD16Ext;
41
+ else if (dim % 4 == 0)
42
+ fstdistfunc_ = L2SqrSIMD4Ext;
43
+ else if (dim > 16)
44
+ fstdistfunc_ = L2SqrSIMD16ExtResiduals;
45
+ else if (dim > 4)
46
+ fstdistfunc_ = L2SqrSIMD4ExtResiduals;
47
+ #endif
48
+ dim_ = dim;
49
+ vector_size_ = dim * sizeof(float);
50
+ data_size_ = vector_size_ + sizeof(DOCIDTYPE);
51
+ }
52
+
53
+ size_t get_data_size() override {
54
+ return data_size_;
55
+ }
56
+
57
+ DISTFUNC<float> get_dist_func() override {
58
+ return fstdistfunc_;
59
+ }
60
+
61
+ void *get_dist_func_param() override {
62
+ return &dim_;
63
+ }
64
+
65
+ DOCIDTYPE get_doc_id(const void *datapoint) override {
66
+ return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
67
+ }
68
+
69
+ void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
70
+ *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
71
+ }
72
+
73
+ ~MultiVectorL2Space() {}
74
+ };
75
+
76
+
77
+ template<typename DOCIDTYPE>
78
+ class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
79
+ DISTFUNC<float> fstdistfunc_;
80
+ size_t data_size_;
81
+ size_t vector_size_;
82
+ size_t dim_;
83
+
84
+ public:
85
+ MultiVectorInnerProductSpace(size_t dim) {
86
+ fstdistfunc_ = InnerProductDistance;
87
+ #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
88
+ #if defined(USE_AVX512)
89
+ if (AVX512Capable()) {
90
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
91
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
92
+ } else if (AVXCapable()) {
93
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
94
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
95
+ }
96
+ #elif defined(USE_AVX)
97
+ if (AVXCapable()) {
98
+ InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
99
+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
100
+ }
101
+ #endif
102
+ #if defined(USE_AVX)
103
+ if (AVXCapable()) {
104
+ InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
105
+ InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
106
+ }
107
+ #endif
108
+
109
+ if (dim % 16 == 0)
110
+ fstdistfunc_ = InnerProductDistanceSIMD16Ext;
111
+ else if (dim % 4 == 0)
112
+ fstdistfunc_ = InnerProductDistanceSIMD4Ext;
113
+ else if (dim > 16)
114
+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
115
+ else if (dim > 4)
116
+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
117
+ #endif
118
+ vector_size_ = dim * sizeof(float);
119
+ data_size_ = vector_size_ + sizeof(DOCIDTYPE);
120
+ }
121
+
122
+ size_t get_data_size() override {
123
+ return data_size_;
124
+ }
125
+
126
+ DISTFUNC<float> get_dist_func() override {
127
+ return fstdistfunc_;
128
+ }
129
+
130
+ void *get_dist_func_param() override {
131
+ return &dim_;
132
+ }
133
+
134
+ DOCIDTYPE get_doc_id(const void *datapoint) override {
135
+ return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
136
+ }
137
+
138
+ void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
139
+ *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
140
+ }
141
+
142
+ ~MultiVectorInnerProductSpace() {}
143
+ };
144
+
145
+
146
+ template<typename DOCIDTYPE, typename dist_t>
147
+ class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
148
+ size_t curr_num_docs_;
149
+ size_t num_docs_to_search_;
150
+ size_t ef_collection_;
151
+ std::unordered_map<DOCIDTYPE, size_t> doc_counter_;
152
+ std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;
153
+ BaseMultiVectorSpace<DOCIDTYPE>& space_;
154
+
155
+ public:
156
+ MultiVectorSearchStopCondition(
157
+ BaseMultiVectorSpace<DOCIDTYPE>& space,
158
+ size_t num_docs_to_search,
159
+ size_t ef_collection = 10)
160
+ : space_(space) {
161
+ curr_num_docs_ = 0;
162
+ num_docs_to_search_ = num_docs_to_search;
163
+ ef_collection_ = std::max(ef_collection, num_docs_to_search);
164
+ }
165
+
166
+ void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
167
+ DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
168
+ if (doc_counter_[doc_id] == 0) {
169
+ curr_num_docs_ += 1;
170
+ }
171
+ search_results_.emplace(dist, doc_id);
172
+ doc_counter_[doc_id] += 1;
173
+ }
174
+
175
+ void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
176
+ DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
177
+ doc_counter_[doc_id] -= 1;
178
+ if (doc_counter_[doc_id] == 0) {
179
+ curr_num_docs_ -= 1;
180
+ }
181
+ search_results_.pop();
182
+ }
183
+
184
+ bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
185
+ bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
186
+ return stop_search;
187
+ }
188
+
189
+ bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
190
+ bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
191
+ return flag_consider_candidate;
192
+ }
193
+
194
+ bool should_remove_extra() override {
195
+ bool flag_remove_extra = curr_num_docs_ > ef_collection_;
196
+ return flag_remove_extra;
197
+ }
198
+
199
+ void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
200
+ while (curr_num_docs_ > num_docs_to_search_) {
201
+ dist_t dist_cand = candidates.back().first;
202
+ dist_t dist_res = search_results_.top().first;
203
+ assert(dist_cand == dist_res);
204
+ DOCIDTYPE doc_id = search_results_.top().second;
205
+ doc_counter_[doc_id] -= 1;
206
+ if (doc_counter_[doc_id] == 0) {
207
+ curr_num_docs_ -= 1;
208
+ }
209
+ search_results_.pop();
210
+ candidates.pop_back();
211
+ }
212
+ }
213
+
214
+ ~MultiVectorSearchStopCondition() {}
215
+ };
216
+
217
+
218
+ template<typename dist_t>
219
+ class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
220
+ float epsilon_;
221
+ size_t min_num_candidates_;
222
+ size_t max_num_candidates_;
223
+ size_t curr_num_items_;
224
+
225
+ public:
226
+ EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
227
+ assert(min_num_candidates <= max_num_candidates);
228
+ epsilon_ = epsilon;
229
+ min_num_candidates_ = min_num_candidates;
230
+ max_num_candidates_ = max_num_candidates;
231
+ curr_num_items_ = 0;
232
+ }
233
+
234
+ void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
235
+ curr_num_items_ += 1;
236
+ }
237
+
238
+ void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
239
+ curr_num_items_ -= 1;
240
+ }
241
+
242
+ bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
243
+ if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
244
+ // new candidate can't improve found results
245
+ return true;
246
+ }
247
+ if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
248
+ // new candidate is out of epsilon region and
249
+ // minimum number of candidates is checked
250
+ return true;
251
+ }
252
+ return false;
253
+ }
254
+
255
+ bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
256
+ bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
257
+ return flag_consider_candidate;
258
+ }
259
+
260
+ bool should_remove_extra() {
261
+ bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
262
+ return flag_remove_extra;
263
+ }
264
+
265
+ void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
266
+ while (!candidates.empty() && candidates.back().first > epsilon_) {
267
+ candidates.pop_back();
268
+ }
269
+ while (candidates.size() > max_num_candidates_) {
270
+ candidates.pop_back();
271
+ }
272
+ }
273
+
274
+ ~EpsilonSearchStopCondition() {}
275
+ };
276
+ } // namespace hnswlib
@@ -0,0 +1,78 @@
1
+ #pragma once
2
+
3
+ #include <mutex>
4
+ #include <string.h>
5
+ #include <deque>
6
+
7
+ namespace hnswlib {
8
typedef unsigned short int vl_type;

/// Per-search "visited" marker array with an epoch tag.
///
/// `mass` holds one `vl_type` tag per element and `curV` is the current tag.
/// `reset()` advances the tag; the array is only zeroed when the tag wraps
/// around to 0. The array is NOT initialized by the constructor: `reset()`
/// must be called before the first use (the first call zeroes the array and
/// leaves `curV == 1`).
class VisitedList {
 public:
    vl_type curV;
    vl_type *mass;
    unsigned int numelements;

    /// @param numelements1  Number of tags to allocate.
    VisitedList(int numelements1) {
        // Largest vl_type value, so the first reset() wraps to 0 and clears `mass`.
        curV = -1;
        numelements = numelements1;
        mass = new vl_type[numelements];
    }

    // The object owns `mass`; an implicit copy would delete[] it twice.
    VisitedList(const VisitedList &) = delete;
    VisitedList &operator=(const VisitedList &) = delete;

    /// Starts a new epoch: bumps `curV`, and on wrap-around zeroes `mass`
    /// and skips tag 0 so that a zeroed slot never equals the current tag.
    void reset() {
        curV++;
        if (curV == 0) {
            memset(mass, 0, sizeof(vl_type) * numelements);
            curV++;
        }
    }

    ~VisitedList() { delete[] mass; }
};
32
+ ///////////////////////////////////////////////////////////
33
+ //
34
+ // Class for multi-threaded pool-management of VisitedLists
35
+ //
36
+ /////////////////////////////////////////////////////////
37
+
38
+ class VisitedListPool {
39
+ std::deque<VisitedList *> pool;
40
+ std::mutex poolguard;
41
+ int numelements;
42
+
43
+ public:
44
+ VisitedListPool(int initmaxpools, int numelements1) {
45
+ numelements = numelements1;
46
+ for (int i = 0; i < initmaxpools; i++)
47
+ pool.push_front(new VisitedList(numelements));
48
+ }
49
+
50
+ VisitedList *getFreeVisitedList() {
51
+ VisitedList *rez;
52
+ {
53
+ std::unique_lock <std::mutex> lock(poolguard);
54
+ if (pool.size() > 0) {
55
+ rez = pool.front();
56
+ pool.pop_front();
57
+ } else {
58
+ rez = new VisitedList(numelements);
59
+ }
60
+ }
61
+ rez->reset();
62
+ return rez;
63
+ }
64
+
65
+ void releaseVisitedList(VisitedList *vl) {
66
+ std::unique_lock <std::mutex> lock(poolguard);
67
+ pool.push_front(vl);
68
+ }
69
+
70
+ ~VisitedListPool() {
71
+ while (pool.size()) {
72
+ VisitedList *rez = pool.front();
73
+ pool.pop_front();
74
+ delete rez;
75
+ }
76
+ }
77
+ };
78
+ } // namespace hnswlib