hnswlib 0.6.2 → 0.8.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  #pragma once
2
2
  #ifndef NO_MANUAL_VECTORIZATION
3
- #ifdef __SSE__
3
+ #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
4
4
  #define USE_SSE
5
5
  #ifdef __AVX__
6
6
  #define USE_AVX
@@ -15,21 +15,20 @@
15
15
  #ifdef _MSC_VER
16
16
  #include <intrin.h>
17
17
  #include <stdexcept>
18
- #include "cpu_x86.h"
19
- void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
18
+ void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
20
19
  __cpuidex(out, eax, ecx);
21
20
  }
22
- __int64 xgetbv(unsigned int x) {
21
+ static __int64 xgetbv(unsigned int x) {
23
22
  return _xgetbv(x);
24
23
  }
25
24
  #else
26
25
  #include <x86intrin.h>
27
26
  #include <cpuid.h>
28
27
  #include <stdint.h>
29
- void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
28
+ static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
30
29
  __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
31
30
  }
32
- uint64_t xgetbv(unsigned int index) {
31
+ static uint64_t xgetbv(unsigned int index) {
33
32
  uint32_t eax, edx;
34
33
  __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
35
34
  return ((uint64_t)edx << 32) | eax;
@@ -51,7 +50,7 @@ uint64_t xgetbv(unsigned int index) {
51
50
  // Adapted from https://github.com/Mysticial/FeatureDetector
52
51
  #define _XCR_XFEATURE_ENABLED_MASK 0
53
52
 
54
- bool AVXCapable() {
53
+ static bool AVXCapable() {
55
54
  int cpuInfo[4];
56
55
 
57
56
  // CPU support
@@ -78,7 +77,7 @@ bool AVXCapable() {
78
77
  return HW_AVX && avxSupported;
79
78
  }
80
79
 
81
- bool AVX512Capable() {
80
+ static bool AVX512Capable() {
82
81
  if (!AVXCapable()) return false;
83
82
 
84
83
  int cpuInfo[4];
@@ -88,7 +87,7 @@ bool AVX512Capable() {
88
87
  int nIds = cpuInfo[0];
89
88
 
90
89
  bool HW_AVX512F = false;
91
- if (nIds >= 0x00000007) { // AVX512 Foundation
90
+ if (nIds >= 0x00000007) { // AVX512 Foundation
92
91
  cpuid(cpuInfo, 0x00000007, 0);
93
92
  HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
94
93
  }
@@ -114,77 +113,86 @@ bool AVX512Capable() {
114
113
  #include <string.h>
115
114
 
116
115
  namespace hnswlib {
117
- typedef size_t labeltype;
118
-
119
- template <typename T>
120
- class pairGreater {
121
- public:
122
- bool operator()(const T& p1, const T& p2) {
123
- return p1.first > p2.first;
124
- }
125
- };
126
-
127
- template<typename T>
128
- static void writeBinaryPOD(std::ostream &out, const T &podRef) {
129
- out.write((char *) &podRef, sizeof(T));
116
+ typedef size_t labeltype;
117
+
118
+ // This can be extended to store state for filtering (e.g. from a std::set)
119
+ class BaseFilterFunctor {
120
+ public:
121
+ virtual bool operator()(hnswlib::labeltype id) { return true; }
122
+ virtual ~BaseFilterFunctor() {};
123
+ };
124
+
125
+ template <typename T>
126
+ class pairGreater {
127
+ public:
128
+ bool operator()(const T& p1, const T& p2) {
129
+ return p1.first > p2.first;
130
130
  }
131
+ };
131
132
 
132
- template<typename T>
133
- static void readBinaryPOD(std::istream &in, T &podRef) {
134
- in.read((char *) &podRef, sizeof(T));
135
- }
133
+ template<typename T>
134
+ static void writeBinaryPOD(std::ostream &out, const T &podRef) {
135
+ out.write((char *) &podRef, sizeof(T));
136
+ }
136
137
 
137
- template<typename MTYPE>
138
- using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
138
+ template<typename T>
139
+ static void readBinaryPOD(std::istream &in, T &podRef) {
140
+ in.read((char *) &podRef, sizeof(T));
141
+ }
139
142
 
143
+ template<typename MTYPE>
144
+ using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
140
145
 
141
- template<typename MTYPE>
142
- class SpaceInterface {
143
- public:
144
- //virtual void search(void *);
145
- virtual size_t get_data_size() = 0;
146
+ template<typename MTYPE>
147
+ class SpaceInterface {
148
+ public:
149
+ // virtual void search(void *);
150
+ virtual size_t get_data_size() = 0;
146
151
 
147
- virtual DISTFUNC<MTYPE> get_dist_func() = 0;
152
+ virtual DISTFUNC<MTYPE> get_dist_func() = 0;
148
153
 
149
- virtual void *get_dist_func_param() = 0;
154
+ virtual void *get_dist_func_param() = 0;
150
155
 
151
- virtual ~SpaceInterface() {}
152
- };
156
+ virtual ~SpaceInterface() {}
157
+ };
153
158
 
154
- template<typename dist_t>
155
- class AlgorithmInterface {
156
- public:
157
- virtual void addPoint(const void *datapoint, labeltype label)=0;
158
- virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
159
+ template<typename dist_t>
160
+ class AlgorithmInterface {
161
+ public:
162
+ virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;
159
163
 
160
- // Return k nearest neighbor in the order of closer fist
161
- virtual std::vector<std::pair<dist_t, labeltype>>
162
- searchKnnCloserFirst(const void* query_data, size_t k) const;
164
+ virtual std::priority_queue<std::pair<dist_t, labeltype>>
165
+ searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;
163
166
 
164
- virtual void saveIndex(const std::string &location)=0;
165
- virtual ~AlgorithmInterface(){
166
- }
167
- };
168
-
169
- template<typename dist_t>
170
- std::vector<std::pair<dist_t, labeltype>>
171
- AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
172
- std::vector<std::pair<dist_t, labeltype>> result;
173
-
174
- // here searchKnn returns the result in the order of further first
175
- auto ret = searchKnn(query_data, k);
176
- {
177
- size_t sz = ret.size();
178
- result.resize(sz);
179
- while (!ret.empty()) {
180
- result[--sz] = ret.top();
181
- ret.pop();
182
- }
183
- }
167
+ // Return k nearest neighbor in the order of closer fist
168
+ virtual std::vector<std::pair<dist_t, labeltype>>
169
+ searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;
184
170
 
185
- return result;
171
+ virtual void saveIndex(const std::string &location) = 0;
172
+ virtual ~AlgorithmInterface(){
186
173
  }
174
+ };
175
+
176
+ template<typename dist_t>
177
+ std::vector<std::pair<dist_t, labeltype>>
178
+ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
179
+ BaseFilterFunctor* isIdAllowed) const {
180
+ std::vector<std::pair<dist_t, labeltype>> result;
181
+
182
+ // here searchKnn returns the result in the order of further first
183
+ auto ret = searchKnn(query_data, k, isIdAllowed);
184
+ {
185
+ size_t sz = ret.size();
186
+ result.resize(sz);
187
+ while (!ret.empty()) {
188
+ result[--sz] = ret.top();
189
+ ret.pop();
190
+ }
191
+ }
192
+
193
+ return result;
187
194
  }
195
+ } // namespace hnswlib
188
196
 
189
197
  #include "space_l2.h"
190
198
  #include "space_ip.h"