logosdb 0.7.7 → 0.7.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,228 @@
1
+ #pragma once
2
+
3
+ // https://github.com/nmslib/hnswlib/pull/508
4
+ // This allows others to provide their own error stream (e.g. RcppHNSW)
5
+ #ifndef HNSWLIB_ERR_OVERRIDE
6
+ #define HNSWERR std::cerr
7
+ #else
8
+ #define HNSWERR HNSWLIB_ERR_OVERRIDE
9
+ #endif
10
+
11
+ #ifndef NO_MANUAL_VECTORIZATION
12
+ #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
13
+ #define USE_SSE
14
+ #ifdef __AVX__
15
+ #define USE_AVX
16
+ #ifdef __AVX512F__
17
+ #define USE_AVX512
18
+ #endif
19
+ #endif
20
+ #endif
21
+ #endif
22
+
23
+ #if defined(USE_AVX) || defined(USE_SSE)
24
+ #ifdef _MSC_VER
25
+ #include <intrin.h>
26
+ #include <stdexcept>
27
+ static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
28
+ __cpuidex(out, eax, ecx);
29
+ }
30
+ static __int64 xgetbv(unsigned int x) {
31
+ return _xgetbv(x);
32
+ }
33
+ #else
34
+ #include <x86intrin.h>
35
+ #include <cpuid.h>
36
+ #include <stdint.h>
37
+ static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
38
+ __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
39
+ }
40
+ static uint64_t xgetbv(unsigned int index) {
41
+ uint32_t eax, edx;
42
+ __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
43
+ return ((uint64_t)edx << 32) | eax;
44
+ }
45
+ #endif
46
+
47
+ #if defined(USE_AVX512)
48
+ #include <immintrin.h>
49
+ #endif
50
+
51
+ #if defined(__GNUC__)
52
+ #define PORTABLE_ALIGN32 __attribute__((aligned(32)))
53
+ #define PORTABLE_ALIGN64 __attribute__((aligned(64)))
54
+ #else
55
+ #define PORTABLE_ALIGN32 __declspec(align(32))
56
+ #define PORTABLE_ALIGN64 __declspec(align(64))
57
+ #endif
58
+
59
+ // Adapted from https://github.com/Mysticial/FeatureDetector
60
+ #define _XCR_XFEATURE_ENABLED_MASK 0
61
+
62
// Returns true when both the CPU and the OS support AVX.
// CPU support is read from CPUID leaf 1; OS support requires OSXSAVE plus
// XCR0 indicating that XMM and YMM state are preserved across context
// switches (adapted from https://github.com/Mysticial/FeatureDetector).
static bool AVXCapable() {
    int cpuInfo[4];

    // CPU support
    // Leaf 0: EAX holds the highest supported standard CPUID leaf.
    cpuid(cpuInfo, 0, 0);
    int nIds = cpuInfo[0];

    bool HW_AVX = false;
    if (nIds >= 0x00000001) {
        // Leaf 1: ECX bit 28 = CPU implements AVX.
        cpuid(cpuInfo, 0x00000001, 0);
        HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
    }

    // OS support
    cpuid(cpuInfo, 1, 0);

    // ECX bit 27 = OSXSAVE (OS uses XSAVE/XRSTOR and XGETBV is usable).
    bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
    bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;

    bool avxSupported = false;
    if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
        uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
        // XCR0 bits 1 (SSE/XMM) and 2 (AVX/YMM) must both be enabled: 0x6.
        avxSupported = (xcrFeatureMask & 0x6) == 0x6;
    }
    return HW_AVX && avxSupported;
}
88
+
89
// Returns true when both the CPU and the OS support AVX-512 Foundation.
// Builds on AVXCapable(): additionally checks CPUID leaf 7 (AVX512F) and
// that the OS saves opmask and ZMM state (XCR0 bits 5-7).
static bool AVX512Capable() {
    if (!AVXCapable()) return false;

    int cpuInfo[4];

    // CPU support
    // Leaf 0: EAX holds the highest supported standard CPUID leaf.
    cpuid(cpuInfo, 0, 0);
    int nIds = cpuInfo[0];

    bool HW_AVX512F = false;
    if (nIds >= 0x00000007) {  // AVX512 Foundation
        // Leaf 7 / subleaf 0: EBX bit 16 = AVX512F.
        cpuid(cpuInfo, 0x00000007, 0);
        HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
    }

    // OS support
    cpuid(cpuInfo, 1, 0);

    // ECX bit 27 = OSXSAVE; bit 28 = AVX (rechecked for the XGETBV path).
    bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
    bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;

    bool avx512Supported = false;
    if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
        uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
        // 0xe6 = bits 1,2 (XMM/YMM) plus bits 5,6,7 (opmask, ZMM_Hi256,
        // Hi16_ZMM): the OS must preserve full AVX-512 register state.
        avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
    }
    return HW_AVX512F && avx512Supported;
}
117
+ #endif
118
+
119
+ #include <queue>
120
+ #include <vector>
121
+ #include <iostream>
122
+ #include <string.h>
123
+
124
+ namespace hnswlib {
125
+ typedef size_t labeltype;
126
+
127
+ // This can be extended to store state for filtering (e.g. from a std::set)
128
+ class BaseFilterFunctor {
129
+ public:
130
+ virtual bool operator()(hnswlib::labeltype id) { return true; }
131
+ virtual ~BaseFilterFunctor() {};
132
+ };
133
+
134
// Strategy interface letting a search customize its termination and result
// bookkeeping (concrete implementations live in stop_condition.h).
template<typename dist_t>
class BaseSearchStopCondition {
 public:
    // Hook: (label, datapoint, dist) was added to the current result set.
    virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;

    // Hook: a previously added point was evicted from the result set.
    virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;

    // Whether the search loop should terminate, given the next candidate's
    // distance and the result set's current lower bound.
    virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;

    // Whether a candidate at this distance is worth considering at all.
    virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;

    // Whether the result set currently holds more entries than desired.
    virtual bool should_remove_extra() = 0;

    // Final pass over the collected (distance, label) pairs before returning.
    virtual void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) = 0;

    virtual ~BaseSearchStopCondition() {}
};
151
+
152
// Comparator ordering pairs by descending `first` (e.g. for min-heap use
// with std::priority_queue, whose default comparator would be a max-heap).
// operator() is now const so the comparator works when held by const
// containers/algorithms; this is backward compatible for all existing uses.
template <typename T>
class pairGreater {
 public:
    bool operator()(const T& p1, const T& p2) const {
        return p1.first > p2.first;
    }
};
159
+
160
// Serialize a trivially-copyable value to `out` as raw native-endian bytes.
// Counterpart of readBinaryPOD; used by the index save/load paths.
// The original C-style `(char *)` cast silently stripped const from podRef;
// reinterpret_cast<const char *> expresses the intended, const-safe pun.
template<typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
    out.write(reinterpret_cast<const char *>(&podRef), sizeof(T));
}
164
+
165
// Deserialize a trivially-copyable value from `in` (raw native-endian
// bytes, as produced by writeBinaryPOD). Replaces the C-style cast with a
// named reinterpret_cast per the cast guidelines.
template<typename T>
static void readBinaryPOD(std::istream &in, T &podRef) {
    in.read(reinterpret_cast<char *>(&podRef), sizeof(T));
}
169
+
170
+ template<typename MTYPE>
171
+ using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
172
+
173
// Abstract description of a metric space: element storage size, the distance
// function, and the opaque parameter handed to that function (for the spaces
// in this library, a pointer to the dimension count).
template<typename MTYPE>
class SpaceInterface {
 public:
    // virtual void search(void *);
    // Size in bytes of one stored vector.
    virtual size_t get_data_size() = 0;

    virtual DISTFUNC<MTYPE> get_dist_func() = 0;

    // Opaque third argument for the DISTFUNC returned above.
    virtual void *get_dist_func_param() = 0;

    virtual ~SpaceInterface() {}
};
185
+
186
// Common interface of the index implementations (see bruteforce.h and
// hnswalg.h, included at the bottom of this header).
template<typename dist_t>
class AlgorithmInterface {
 public:
    // Insert one vector under `label`. `replace_deleted` is forwarded to the
    // implementation (presumably reuses deleted slots — see hnswalg.h).
    virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;

    // k-NN search. Returns a max-heap, i.e. the FARTHEST result is on top;
    // candidates may be filtered via `isIdAllowed`.
    virtual std::priority_queue<std::pair<dist_t, labeltype>>
    searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

    // Return the k nearest neighbors ordered closest-first
    virtual std::vector<std::pair<dist_t, labeltype>>
    searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

    virtual void saveIndex(const std::string &location) = 0;
    virtual ~AlgorithmInterface(){
    }
};
202
+
203
+ template<typename dist_t>
204
+ std::vector<std::pair<dist_t, labeltype>>
205
+ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
206
+ BaseFilterFunctor* isIdAllowed) const {
207
+ std::vector<std::pair<dist_t, labeltype>> result;
208
+
209
+ // here searchKnn returns the result in the order of further first
210
+ auto ret = searchKnn(query_data, k, isIdAllowed);
211
+ {
212
+ size_t sz = ret.size();
213
+ result.resize(sz);
214
+ while (!ret.empty()) {
215
+ result[--sz] = ret.top();
216
+ ret.pop();
217
+ }
218
+ }
219
+
220
+ return result;
221
+ }
222
+ } // namespace hnswlib
223
+
224
+ #include "space_l2.h"
225
+ #include "space_ip.h"
226
+ #include "stop_condition.h"
227
+ #include "bruteforce.h"
228
+ #include "hnswalg.h"
@@ -0,0 +1,400 @@
1
+ #pragma once
2
+ #include "hnswlib.h"
3
+
4
+ namespace hnswlib {
5
+
6
// Scalar (non-SIMD) dot product of two float vectors; the portable fallback
// and the tail handler for the ...Residuals kernels below.
// Per the DISTFUNC convention, qty_ptr points at the dimension count
// (a size_t). Returns the raw similarity, not a distance.
static float
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
    size_t qty = *((size_t *) qty_ptr);
    float res = 0;
    // size_t index: the original `unsigned i` would wrap before reaching qty
    // on platforms where size_t is wider than unsigned int.
    for (size_t i = 0; i < qty; i++) {
        res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
    }
    return res;
}
15
+
16
+ static float
17
+ InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
18
+ return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
19
+ }
20
+
21
+ #if defined(USE_AVX)
22
+
23
+ // Favor using AVX if available.
24
// AVX inner product for dimensions that are a multiple of 4 (the dispatch in
// InnerProductSpace only selects SIMD4 kernels when dim % 4 == 0, or routes
// through the ...Residuals wrappers). Bulk of the vector is consumed 16
// floats per iteration with 256-bit loads; the remaining multiple-of-4 tail
// is consumed 4 floats at a time with SSE.
static float
InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float PORTABLE_ALIGN32 TmpRes[8];  // aligned spill buffer for the final horizontal sum
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);

    size_t qty16 = qty / 16;
    size_t qty4 = qty / 4;

    // pEnd1: end of the 16-float-chunk region; pEnd2: end of the 4-float-chunk
    // region (both measured from the original pVect1).
    const float *pEnd1 = pVect1 + 16 * qty16;
    const float *pEnd2 = pVect1 + 4 * qty4;

    __m256 sum256 = _mm256_set1_ps(0);

    // Two 8-lane multiply-accumulates per iteration = 16 floats.
    while (pVect1 < pEnd1) {
        //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);

        __m256 v1 = _mm256_loadu_ps(pVect1);
        pVect1 += 8;
        __m256 v2 = _mm256_loadu_ps(pVect2);
        pVect2 += 8;
        sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));

        v1 = _mm256_loadu_ps(pVect1);
        pVect1 += 8;
        v2 = _mm256_loadu_ps(pVect2);
        pVect2 += 8;
        sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
    }

    __m128 v1, v2;
    // Collapse the 256-bit accumulator into a 128-bit one (low + high lanes).
    __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));

    // Remaining multiple-of-4 tail, 4 floats per iteration.
    while (pVect1 < pEnd2) {
        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
    }

    // Horizontal sum of the 4 accumulator lanes.
    _mm_store_ps(TmpRes, sum_prod);
    float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
    return sum;
}
70
+
71
+ static float
72
+ InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
73
+ return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
74
+ }
75
+
76
+ #endif
77
+
78
+ #if defined(USE_SSE)
79
+
80
// SSE inner product for dimensions that are a multiple of 4. The main loop
// is unrolled 4x (16 floats per iteration); the remaining multiple-of-4
// tail is handled one 128-bit chunk at a time.
static float
InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float PORTABLE_ALIGN32 TmpRes[8];  // aligned spill buffer for the final horizontal sum
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);

    size_t qty16 = qty / 16;
    size_t qty4 = qty / 4;

    // pEnd1: end of the unrolled 16-float region; pEnd2: end of the 4-float
    // region (both measured from the original pVect1).
    const float *pEnd1 = pVect1 + 16 * qty16;
    const float *pEnd2 = pVect1 + 4 * qty4;

    __m128 v1, v2;
    __m128 sum_prod = _mm_set1_ps(0);

    // 4x-unrolled multiply-accumulate: 16 floats per iteration.
    while (pVect1 < pEnd1) {
        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
    }

    // Remaining multiple-of-4 tail.
    while (pVect1 < pEnd2) {
        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
    }

    // Horizontal sum of the 4 accumulator lanes.
    _mm_store_ps(TmpRes, sum_prod);
    float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];

    return sum;
}
135
+
136
+ static float
137
+ InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
138
+ return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
139
+ }
140
+
141
+ #endif
142
+
143
+
144
+ #if defined(USE_AVX512)
145
+
146
// AVX-512 inner product for dimensions that are a multiple of 16. The main
// loop is unrolled 4x (64 floats per iteration) using FMA; leftover 16-float
// chunks are consumed one at a time, then the 16 accumulator lanes are
// reduced with _mm512_reduce_add_ps.
static float
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float PORTABLE_ALIGN64 TmpRes[16];  // NOTE: unused here (reduce_add does the horizontal sum); kept for parity with the other kernels
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);

    size_t qty16 = qty / 16;


    const float *pEnd1 = pVect1 + 16 * qty16;

    __m512 sum512 = _mm512_set1_ps(0);

    // Number of full 64-float (4 x 16) iterations.
    size_t loop = qty16 / 4;

    while (loop--) {
        __m512 v1 = _mm512_loadu_ps(pVect1);
        __m512 v2 = _mm512_loadu_ps(pVect2);
        pVect1 += 16;
        pVect2 += 16;

        __m512 v3 = _mm512_loadu_ps(pVect1);
        __m512 v4 = _mm512_loadu_ps(pVect2);
        pVect1 += 16;
        pVect2 += 16;

        __m512 v5 = _mm512_loadu_ps(pVect1);
        __m512 v6 = _mm512_loadu_ps(pVect2);
        pVect1 += 16;
        pVect2 += 16;

        __m512 v7 = _mm512_loadu_ps(pVect1);
        __m512 v8 = _mm512_loadu_ps(pVect2);
        pVect1 += 16;
        pVect2 += 16;

        // Fused multiply-add into the single accumulator.
        sum512 = _mm512_fmadd_ps(v1, v2, sum512);
        sum512 = _mm512_fmadd_ps(v3, v4, sum512);
        sum512 = _mm512_fmadd_ps(v5, v6, sum512);
        sum512 = _mm512_fmadd_ps(v7, v8, sum512);
    }

    // Remaining 16-float chunks (qty16 % 4 of them).
    while (pVect1 < pEnd1) {
        __m512 v1 = _mm512_loadu_ps(pVect1);
        __m512 v2 = _mm512_loadu_ps(pVect2);
        pVect1 += 16;
        pVect2 += 16;
        sum512 = _mm512_fmadd_ps(v1, v2, sum512);
    }

    float sum = _mm512_reduce_add_ps(sum512);
    return sum;
}
200
+
201
+ static float
202
+ InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
203
+ return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
204
+ }
205
+
206
+ #endif
207
+
208
+ #if defined(USE_AVX)
209
+
210
// AVX inner product for dimensions that are a multiple of 16: two 8-lane
// multiply-accumulates per iteration, then a horizontal sum of all 8
// accumulator lanes through an aligned spill buffer.
static float
InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float PORTABLE_ALIGN32 TmpRes[8];  // 32-byte aligned for _mm256_store_ps
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);

    size_t qty16 = qty / 16;


    const float *pEnd1 = pVect1 + 16 * qty16;

    __m256 sum256 = _mm256_set1_ps(0);

    // 16 floats per iteration (2 x 8-lane chunks).
    while (pVect1 < pEnd1) {
        //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);

        __m256 v1 = _mm256_loadu_ps(pVect1);
        pVect1 += 8;
        __m256 v2 = _mm256_loadu_ps(pVect2);
        pVect2 += 8;
        sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));

        v1 = _mm256_loadu_ps(pVect1);
        pVect1 += 8;
        v2 = _mm256_loadu_ps(pVect2);
        pVect2 += 8;
        sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
    }

    // Horizontal sum of all 8 lanes.
    _mm256_store_ps(TmpRes, sum256);
    float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];

    return sum;
}
245
+
246
+ static float
247
+ InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
248
+ return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
249
+ }
250
+
251
+ #endif
252
+
253
+ #if defined(USE_SSE)
254
+
255
// SSE inner product for dimensions that are a multiple of 16: the loop is
// unrolled 4x (4 x 4 floats per iteration), then the 4 accumulator lanes
// are summed through an aligned spill buffer.
static float
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float PORTABLE_ALIGN32 TmpRes[8];  // aligned spill buffer for the final horizontal sum
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);

    size_t qty16 = qty / 16;

    const float *pEnd1 = pVect1 + 16 * qty16;

    __m128 v1, v2;
    __m128 sum_prod = _mm_set1_ps(0);

    // 16 floats per iteration (4x-unrolled multiply-accumulate).
    while (pVect1 < pEnd1) {
        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
    }
    // Horizontal sum of the 4 accumulator lanes.
    _mm_store_ps(TmpRes, sum_prod);
    float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];

    return sum;
}
299
+
300
+ static float
301
+ InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
302
+ return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
303
+ }
304
+
305
+ #endif
306
+
307
+ #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
308
// Runtime-dispatched kernel pointers. They default to the SSE variants and
// are upgraded to the AVX / AVX-512 implementations by InnerProductSpace's
// constructor when the CPU and OS support them.
static DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
static DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
static DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
static DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
312
+
313
// Distance for dimensions > 16 that are NOT a multiple of 16: run the
// 16-wide SIMD kernel on the largest multiple of 16, finish the remaining
// (< 16) components with the scalar InnerProduct, and combine.
static float
InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    size_t qty = *((size_t *) qty_ptr);
    size_t qty16 = qty >> 4 << 4;  // round qty down to a multiple of 16
    float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
    // Advance past the SIMD-covered prefix for the scalar tail.
    float *pVect1 = (float *) pVect1v + qty16;
    float *pVect2 = (float *) pVect2v + qty16;

    size_t qty_left = qty - qty16;
    float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
    return 1.0f - (res + res_tail);
}
325
+
326
// Distance for dimensions in (4, 16] that are NOT a multiple of 4: run the
// 4-wide SIMD kernel on the largest multiple of 4, finish the remaining
// (< 4) components with the scalar InnerProduct, and combine.
static float
InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    size_t qty = *((size_t *) qty_ptr);
    size_t qty4 = qty >> 2 << 2;  // round qty down to a multiple of 4

    float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
    size_t qty_left = qty - qty4;

    // Advance past the SIMD-covered prefix for the scalar tail.
    float *pVect1 = (float *) pVect1v + qty4;
    float *pVect2 = (float *) pVect2v + qty4;
    float res_tail = InnerProduct(pVect1, pVect2, &qty_left);

    return 1.0f - (res + res_tail);
}
340
+ #endif
341
+
342
// SpaceInterface implementation for inner-product (1 - dot product)
// distance over float vectors of a fixed dimension.
class InnerProductSpace : public SpaceInterface<float> {
    DISTFUNC<float> fstdistfunc_;  // kernel chosen at construction time
    size_t data_size_;             // bytes per vector: dim * sizeof(float)
    size_t dim_;                   // dimension; also the DISTFUNC parameter

 public:
    // Picks the best distance kernel for `dim` and the host CPU.
    // NOTE(review): this constructor rewrites the file-scope
    // InnerProduct*Ext dispatch pointers; constructing spaces concurrently
    // from multiple threads would race on those globals — confirm
    // construction is single-threaded before relying on it.
    InnerProductSpace(size_t dim) {
        fstdistfunc_ = InnerProductDistance;  // scalar fallback (also used for dim <= 4 not divisible by 4)
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
    #if defined(USE_AVX512)
        // Upgrade the 16-wide kernels to the best supported ISA.
        if (AVX512Capable()) {
            InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
            InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
        } else if (AVXCapable()) {
            InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
            InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
        }
    #elif defined(USE_AVX)
        if (AVXCapable()) {
            InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
            InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
        }
    #endif
    #if defined(USE_AVX)
        // 4-wide kernels have only an AVX upgrade (no AVX-512 variant).
        if (AVXCapable()) {
            InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
            InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
        }
    #endif

        // Select by dimensionality: exact multiples use the pure SIMD
        // kernels; other sizes above the chunk width use the residual
        // kernels (SIMD prefix + scalar tail); dim <= 4 keeps the scalar
        // fallback assigned above.
        if (dim % 16 == 0)
            fstdistfunc_ = InnerProductDistanceSIMD16Ext;
        else if (dim % 4 == 0)
            fstdistfunc_ = InnerProductDistanceSIMD4Ext;
        else if (dim > 16)
            fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
        else if (dim > 4)
            fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
#endif
        dim_ = dim;
        data_size_ = dim * sizeof(float);
    }

    size_t get_data_size() {
        return data_size_;
    }

    DISTFUNC<float> get_dist_func() {
        return fstdistfunc_;
    }

    // Returns a pointer to dim_; the kernels read it back as a size_t.
    void *get_dist_func_param() {
        return &dim_;
    }

    ~InnerProductSpace() {}
};
399
+
400
+ } // namespace hnswlib