faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -7,17 +7,14 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <string>
10
+ #include <algorithm>
11
11
  #include <cstdint>
12
12
  #include <cstring>
13
- #include <functional>
14
- #include <algorithm>
13
+ #include <string>
15
14
 
16
15
  namespace faiss {
17
16
 
18
-
19
17
  struct simd256bit {
20
-
21
18
  union {
22
19
  uint8_t u8[32];
23
20
  uint16_t u16[16];
@@ -27,8 +24,7 @@ struct simd256bit {
27
24
 
28
25
  simd256bit() {}
29
26
 
30
- explicit simd256bit(const void *x)
31
- {
27
+ explicit simd256bit(const void* x) {
32
28
  memcpy(u8, x, 32);
33
29
  }
34
30
 
@@ -36,20 +32,20 @@ struct simd256bit {
36
32
  memset(u8, 0, 32);
37
33
  }
38
34
 
39
- void storeu(void *ptr) const {
35
+ void storeu(void* ptr) const {
40
36
  memcpy(ptr, u8, 32);
41
37
  }
42
38
 
43
- void loadu(const void *ptr) {
39
+ void loadu(const void* ptr) {
44
40
  memcpy(u8, ptr, 32);
45
41
  }
46
42
 
47
- void store(void *ptr) const {
43
+ void store(void* ptr) const {
48
44
  storeu(ptr);
49
45
  }
50
46
 
51
47
  void bin(char bits[257]) const {
52
- const char *bytes = (char*)this->u8;
48
+ const char* bytes = (char*)this->u8;
53
49
  for (int i = 0; i < 256; i++) {
54
50
  bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
55
51
  }
@@ -61,14 +57,10 @@ struct simd256bit {
61
57
  bin(bits);
62
58
  return std::string(bits);
63
59
  }
64
-
65
60
  };
66
61
 
67
-
68
-
69
-
70
62
  /// vector of 16 elements in uint16
71
- struct simd16uint16: simd256bit {
63
+ struct simd16uint16 : simd256bit {
72
64
  simd16uint16() {}
73
65
 
74
66
  explicit simd16uint16(int x) {
@@ -79,13 +71,13 @@ struct simd16uint16: simd256bit {
79
71
  set1(x);
80
72
  }
81
73
 
82
- explicit simd16uint16(simd256bit x): simd256bit(x) {}
74
+ explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}
83
75
 
84
- explicit simd16uint16(const uint16_t *x): simd256bit((const void*)x) {}
76
+ explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
85
77
 
86
- std::string elements_to_string(const char * fmt) const {
78
+ std::string elements_to_string(const char* fmt) const {
87
79
  char res[1000], *ptr = res;
88
- for(int i = 0; i < 16; i++) {
80
+ for (int i = 0; i < 16; i++) {
89
81
  ptr += sprintf(ptr, fmt, u16[i]);
90
82
  }
91
83
  // strip last ,
@@ -101,88 +93,86 @@ struct simd16uint16: simd256bit {
101
93
  return elements_to_string("%3d,");
102
94
  }
103
95
 
104
- static simd16uint16 unary_func(
105
- simd16uint16 a, std::function<uint16_t (uint16_t)> f)
106
- {
96
+ template <typename F>
97
+ static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
107
98
  simd16uint16 c;
108
- for(int j = 0; j < 16; j++) {
99
+ for (int j = 0; j < 16; j++) {
109
100
  c.u16[j] = f(a.u16[j]);
110
101
  }
111
102
  return c;
112
103
  }
113
104
 
114
-
105
+ template <typename F>
115
106
  static simd16uint16 binary_func(
116
- simd16uint16 a, simd16uint16 b,
117
- std::function<uint16_t (uint16_t, uint16_t)> f)
118
- {
107
+ const simd16uint16& a,
108
+ const simd16uint16& b,
109
+ F&& f) {
119
110
  simd16uint16 c;
120
- for(int j = 0; j < 16; j++) {
111
+ for (int j = 0; j < 16; j++) {
121
112
  c.u16[j] = f(a.u16[j], b.u16[j]);
122
113
  }
123
114
  return c;
124
115
  }
125
116
 
126
117
  void set1(uint16_t x) {
127
- for(int i = 0; i < 16; i++) {
118
+ for (int i = 0; i < 16; i++) {
128
119
  u16[i] = x;
129
120
  }
130
121
  }
131
122
 
132
123
  // shift must be known at compile time
133
- simd16uint16 operator >> (const int shift) const {
134
- return unary_func(*this, [shift](uint16_t a) {return a >> shift; });
124
+ simd16uint16 operator>>(const int shift) const {
125
+ return unary_func(*this, [shift](uint16_t a) { return a >> shift; });
135
126
  }
136
127
 
137
-
138
128
  // shift must be known at compile time
139
- simd16uint16 operator << (const int shift) const {
140
- return unary_func(*this, [shift](uint16_t a) {return a << shift; });
129
+ simd16uint16 operator<<(const int shift) const {
130
+ return unary_func(*this, [shift](uint16_t a) { return a << shift; });
141
131
  }
142
132
 
143
- simd16uint16 operator += (simd16uint16 other) {
133
+ simd16uint16 operator+=(const simd16uint16& other) {
144
134
  *this = *this + other;
145
135
  return *this;
146
136
  }
147
137
 
148
- simd16uint16 operator -= (simd16uint16 other) {
138
+ simd16uint16 operator-=(const simd16uint16& other) {
149
139
  *this = *this - other;
150
140
  return *this;
151
141
  }
152
142
 
153
- simd16uint16 operator + (simd16uint16 other) const {
154
- return binary_func(*this, other,
155
- [](uint16_t a, uint16_t b) {return a + b; }
156
- );
143
+ simd16uint16 operator+(const simd16uint16& other) const {
144
+ return binary_func(
145
+ *this, other, [](uint16_t a, uint16_t b) { return a + b; });
157
146
  }
158
147
 
159
- simd16uint16 operator - (simd16uint16 other) const {
160
- return binary_func(*this, other,
161
- [](uint16_t a, uint16_t b) {return a - b; }
162
- );
148
+ simd16uint16 operator-(const simd16uint16& other) const {
149
+ return binary_func(
150
+ *this, other, [](uint16_t a, uint16_t b) { return a - b; });
163
151
  }
164
152
 
165
- simd16uint16 operator & (simd256bit other) const {
166
- return binary_func(*this, simd16uint16(other),
167
- [](uint16_t a, uint16_t b) {return a & b; }
168
- );
153
+ simd16uint16 operator&(const simd256bit& other) const {
154
+ return binary_func(
155
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
156
+ return a & b;
157
+ });
169
158
  }
170
159
 
171
- simd16uint16 operator | (simd256bit other) const {
172
- return binary_func(*this, simd16uint16(other),
173
- [](uint16_t a, uint16_t b) {return a | b; }
174
- );
160
+ simd16uint16 operator|(const simd256bit& other) const {
161
+ return binary_func(
162
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
163
+ return a | b;
164
+ });
175
165
  }
176
166
 
177
167
  // returns binary masks
178
- simd16uint16 operator == (simd16uint16 other) const {
179
- return binary_func(*this, other,
180
- [](uint16_t a, uint16_t b) {return a == b ? 0xffff : 0; }
181
- );
168
+ simd16uint16 operator==(const simd16uint16& other) const {
169
+ return binary_func(*this, other, [](uint16_t a, uint16_t b) {
170
+ return a == b ? 0xffff : 0;
171
+ });
182
172
  }
183
173
 
184
- simd16uint16 operator ~() const {
185
- return unary_func(*this, [](uint16_t a) {return ~a; });
174
+ simd16uint16 operator~() const {
175
+ return unary_func(*this, [](uint16_t a) { return ~a; });
186
176
  }
187
177
 
188
178
  // get scalar at index 0
@@ -192,9 +182,9 @@ struct simd16uint16: simd256bit {
192
182
 
193
183
  // mask of elements where this >= thresh
194
184
  // 2 bit per component: 16 * 2 = 32 bit
195
- uint32_t ge_mask(simd16uint16 thresh) const {
185
+ uint32_t ge_mask(const simd16uint16& thresh) const {
196
186
  uint32_t gem = 0;
197
- for(int j = 0; j < 16; j++) {
187
+ for (int j = 0; j < 16; j++) {
198
188
  if (u16[j] >= thresh.u16[j]) {
199
189
  gem |= 3 << (j * 2);
200
190
  }
@@ -202,61 +192,57 @@ struct simd16uint16: simd256bit {
202
192
  return gem;
203
193
  }
204
194
 
205
- uint32_t le_mask(simd16uint16 thresh) const {
195
+ uint32_t le_mask(const simd16uint16& thresh) const {
206
196
  return thresh.ge_mask(*this);
207
197
  }
208
198
 
209
- uint32_t gt_mask(simd16uint16 thresh) const {
199
+ uint32_t gt_mask(const simd16uint16& thresh) const {
210
200
  return ~le_mask(thresh);
211
201
  }
212
202
 
213
- bool all_gt(simd16uint16 thresh) const {
203
+ bool all_gt(const simd16uint16& thresh) const {
214
204
  return le_mask(thresh) == 0;
215
205
  }
216
206
 
217
207
  // for debugging only
218
- uint16_t operator [] (int i) const {
208
+ uint16_t operator[](int i) const {
219
209
  return u16[i];
220
210
  }
221
211
 
222
- void accu_min(simd16uint16 incoming) {
223
- for(int j = 0; j < 16; j++) {
212
+ void accu_min(const simd16uint16& incoming) {
213
+ for (int j = 0; j < 16; j++) {
224
214
  if (incoming.u16[j] < u16[j]) {
225
215
  u16[j] = incoming.u16[j];
226
216
  }
227
217
  }
228
218
  }
229
219
 
230
- void accu_max(simd16uint16 incoming) {
231
- for(int j = 0; j < 16; j++) {
220
+ void accu_max(const simd16uint16& incoming) {
221
+ for (int j = 0; j < 16; j++) {
232
222
  if (incoming.u16[j] > u16[j]) {
233
223
  u16[j] = incoming.u16[j];
234
224
  }
235
225
  }
236
226
  }
237
-
238
227
  };
239
228
 
240
-
241
229
  // not really a std::min because it returns an elementwise min
242
- inline simd16uint16 min(simd16uint16 av, simd16uint16 bv) {
243
- return simd16uint16::binary_func(av, bv,
244
- [](uint16_t a, uint16_t b) {return std::min(a, b); }
245
- );
230
+ inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
231
+ return simd16uint16::binary_func(
232
+ av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
246
233
  }
247
234
 
248
- inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
249
- return simd16uint16::binary_func(av, bv,
250
- [](uint16_t a, uint16_t b) {return std::max(a, b); }
251
- );
235
+ inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
236
+ return simd16uint16::binary_func(
237
+ av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
252
238
  }
253
239
 
254
240
  // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
255
241
  // return (a0 + a1, b0 + b1)
256
242
  // TODO find a better name
257
- inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
243
+ inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
258
244
  simd16uint16 c;
259
- for(int j = 0; j < 8; j++) {
245
+ for (int j = 0; j < 8; j++) {
260
246
  c.u16[j] = a.u16[j] + a.u16[j + 8];
261
247
  c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
262
248
  }
@@ -265,9 +251,12 @@ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
265
251
 
266
252
  // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
267
253
  // of d0 and d1 with thr
268
- inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
254
+ inline uint32_t cmp_ge32(
255
+ const simd16uint16& d0,
256
+ const simd16uint16& d1,
257
+ const simd16uint16& thr) {
269
258
  uint32_t gem = 0;
270
- for(int j = 0; j < 16; j++) {
259
+ for (int j = 0; j < 16; j++) {
271
260
  if (d0.u16[j] >= thr.u16[j]) {
272
261
  gem |= 1 << j;
273
262
  }
@@ -278,10 +267,12 @@ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
278
267
  return gem;
279
268
  }
280
269
 
281
-
282
- inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
270
+ inline uint32_t cmp_le32(
271
+ const simd16uint16& d0,
272
+ const simd16uint16& d1,
273
+ const simd16uint16& thr) {
283
274
  uint32_t gem = 0;
284
- for(int j = 0; j < 16; j++) {
275
+ for (int j = 0; j < 16; j++) {
285
276
  if (d0.u16[j] <= thr.u16[j]) {
286
277
  gem |= 1 << j;
287
278
  }
@@ -292,24 +283,25 @@ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
292
283
  return gem;
293
284
  }
294
285
 
295
-
296
-
297
286
  // vector of 32 unsigned 8-bit integers
298
- struct simd32uint8: simd256bit {
299
-
287
+ struct simd32uint8 : simd256bit {
300
288
  simd32uint8() {}
301
289
 
302
- explicit simd32uint8(int x) {set1(x); }
290
+ explicit simd32uint8(int x) {
291
+ set1(x);
292
+ }
303
293
 
304
- explicit simd32uint8(uint8_t x) {set1(x); }
294
+ explicit simd32uint8(uint8_t x) {
295
+ set1(x);
296
+ }
305
297
 
306
- explicit simd32uint8(simd256bit x): simd256bit(x) {}
298
+ explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
307
299
 
308
- explicit simd32uint8(const uint8_t *x): simd256bit((const void*)x) {}
300
+ explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
309
301
 
310
- std::string elements_to_string(const char * fmt) const {
302
+ std::string elements_to_string(const char* fmt) const {
311
303
  char res[1000], *ptr = res;
312
- for(int i = 0; i < 32; i++) {
304
+ for (int i = 0; i < 32; i++) {
313
305
  ptr += sprintf(ptr, fmt, u8[i]);
314
306
  }
315
307
  // strip last ,
@@ -326,39 +318,38 @@ struct simd32uint8: simd256bit {
326
318
  }
327
319
 
328
320
  void set1(uint8_t x) {
329
- for(int j = 0; j < 32; j++) {
321
+ for (int j = 0; j < 32; j++) {
330
322
  u8[j] = x;
331
323
  }
332
324
  }
333
325
 
326
+ template <typename F>
334
327
  static simd32uint8 binary_func(
335
- simd32uint8 a, simd32uint8 b,
336
- std::function<uint8_t (uint8_t, uint8_t)> f)
337
- {
328
+ const simd32uint8& a,
329
+ const simd32uint8& b,
330
+ F&& f) {
338
331
  simd32uint8 c;
339
- for(int j = 0; j < 32; j++) {
332
+ for (int j = 0; j < 32; j++) {
340
333
  c.u8[j] = f(a.u8[j], b.u8[j]);
341
334
  }
342
335
  return c;
343
336
  }
344
337
 
345
-
346
- simd32uint8 operator & (simd256bit other) const {
347
- return binary_func(*this, simd32uint8(other),
348
- [](uint8_t a, uint8_t b) {return a & b; }
349
- );
338
+ simd32uint8 operator&(const simd256bit& other) const {
339
+ return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
340
+ return a & b;
341
+ });
350
342
  }
351
343
 
352
- simd32uint8 operator + (simd32uint8 other) const {
353
- return binary_func(*this, other,
354
- [](uint8_t a, uint8_t b) {return a + b; }
355
- );
344
+ simd32uint8 operator+(const simd32uint8& other) const {
345
+ return binary_func(
346
+ *this, other, [](uint8_t a, uint8_t b) { return a + b; });
356
347
  }
357
348
 
358
349
  // The very important operation that everything relies on
359
- simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
350
+ simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
360
351
  simd32uint8 c;
361
- for(int j = 0; j < 32; j++) {
352
+ for (int j = 0; j < 32; j++) {
362
353
  if (idx.u8[j] & 0x80) {
363
354
  c.u8[j] = 0;
364
355
  } else {
@@ -376,31 +367,29 @@ struct simd32uint8: simd256bit {
376
367
  // extract + 0-extend lane
377
368
  // this operation is slow (3 cycles)
378
369
 
379
- simd32uint8 operator += (simd32uint8 other) {
370
+ simd32uint8 operator+=(const simd32uint8& other) {
380
371
  *this = *this + other;
381
372
  return *this;
382
373
  }
383
374
 
384
375
  // for debugging only
385
- uint8_t operator [] (int i) const {
376
+ uint8_t operator[](int i) const {
386
377
  return u8[i];
387
378
  }
388
-
389
379
  };
390
380
 
391
-
392
381
  // convert with saturation
393
382
  // careful: this does not cross lanes, so the order is weird
394
- inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
383
+ inline simd32uint8 uint16_to_uint8_saturate(
384
+ const simd16uint16& a,
385
+ const simd16uint16& b) {
395
386
  simd32uint8 c;
396
387
 
397
- auto saturate_16_to_8 = [] (uint16_t x) {
398
- return x >= 256 ? 0xff : x;
399
- };
388
+ auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };
400
389
 
401
390
  for (int i = 0; i < 8; i++) {
402
- c.u8[ i] = saturate_16_to_8(a.u16[i]);
403
- c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
391
+ c.u8[i] = saturate_16_to_8(a.u16[i]);
392
+ c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
404
393
  c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
405
394
  c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
406
395
  }
@@ -408,7 +397,7 @@ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
408
397
  }
409
398
 
410
399
  /// get most significant bit of each byte
411
- inline uint32_t get_MSBs(simd32uint8 a) {
400
+ inline uint32_t get_MSBs(const simd32uint8& a) {
412
401
  uint32_t res = 0;
413
402
  for (int i = 0; i < 32; i++) {
414
403
  if (a.u8[i] & 0x80) {
@@ -419,7 +408,10 @@ inline uint32_t get_MSBs(simd32uint8 a) {
419
408
  }
420
409
 
421
410
  /// use MSB of each byte of mask to select a byte between a and b
422
- inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
411
+ inline simd32uint8 blendv(
412
+ const simd32uint8& a,
413
+ const simd32uint8& b,
414
+ const simd32uint8& mask) {
423
415
  simd32uint8 c;
424
416
  for (int i = 0; i < 32; i++) {
425
417
  if (mask.u8[i] & 0x80) {
@@ -431,23 +423,21 @@ inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
431
423
  return c;
432
424
  }
433
425
 
434
-
435
-
436
-
437
426
  /// vector of 8 unsigned 32-bit integers
438
- struct simd8uint32: simd256bit {
427
+ struct simd8uint32 : simd256bit {
439
428
  simd8uint32() {}
440
429
 
430
+ explicit simd8uint32(uint32_t x) {
431
+ set1(x);
432
+ }
441
433
 
442
- explicit simd8uint32(uint32_t x) {set1(x); }
443
-
444
- explicit simd8uint32(simd256bit x): simd256bit(x) {}
434
+ explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}
445
435
 
446
- explicit simd8uint32(const uint8_t *x): simd256bit((const void*)x) {}
436
+ explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
447
437
 
448
- std::string elements_to_string(const char * fmt) const {
438
+ std::string elements_to_string(const char* fmt) const {
449
439
  char res[1000], *ptr = res;
450
- for(int i = 0; i < 8; i++) {
440
+ for (int i = 0; i < 8; i++) {
451
441
  ptr += sprintf(ptr, fmt, u32[i]);
452
442
  }
453
443
  // strip last ,
@@ -468,69 +458,67 @@ struct simd8uint32: simd256bit {
468
458
  u32[i] = x;
469
459
  }
470
460
  }
471
-
472
461
  };
473
462
 
474
- struct simd8float32: simd256bit {
475
-
463
+ struct simd8float32 : simd256bit {
476
464
  simd8float32() {}
477
465
 
478
- explicit simd8float32(simd256bit x): simd256bit(x) {}
466
+ explicit simd8float32(const simd256bit& x) : simd256bit(x) {}
479
467
 
480
- explicit simd8float32(float x) {set1(x); }
468
+ explicit simd8float32(float x) {
469
+ set1(x);
470
+ }
481
471
 
482
- explicit simd8float32(const float *x) {loadu((void*)x); }
472
+ explicit simd8float32(const float* x) {
473
+ loadu((void*)x);
474
+ }
483
475
 
484
476
  void set1(float x) {
485
- for(int i = 0; i < 8; i++) {
477
+ for (int i = 0; i < 8; i++) {
486
478
  f32[i] = x;
487
479
  }
488
480
  }
489
481
 
482
+ template <typename F>
490
483
  static simd8float32 binary_func(
491
- simd8float32 a, simd8float32 b,
492
- std::function<float (float, float)> f)
493
- {
484
+ const simd8float32& a,
485
+ const simd8float32& b,
486
+ F&& f) {
494
487
  simd8float32 c;
495
- for(int j = 0; j < 8; j++) {
488
+ for (int j = 0; j < 8; j++) {
496
489
  c.f32[j] = f(a.f32[j], b.f32[j]);
497
490
  }
498
491
  return c;
499
492
  }
500
493
 
501
- simd8float32 operator * (simd8float32 other) const {
502
- return binary_func(*this, other,
503
- [](float a, float b) {return a * b; }
504
- );
494
+ simd8float32 operator*(const simd8float32& other) const {
495
+ return binary_func(
496
+ *this, other, [](float a, float b) { return a * b; });
505
497
  }
506
498
 
507
- simd8float32 operator + (simd8float32 other) const {
508
- return binary_func(*this, other,
509
- [](float a, float b) {return a + b; }
510
- );
499
+ simd8float32 operator+(const simd8float32& other) const {
500
+ return binary_func(
501
+ *this, other, [](float a, float b) { return a + b; });
511
502
  }
512
503
 
513
- simd8float32 operator - (simd8float32 other) const {
514
- return binary_func(*this, other,
515
- [](float a, float b) {return a - b; }
516
- );
504
+ simd8float32 operator-(const simd8float32& other) const {
505
+ return binary_func(
506
+ *this, other, [](float a, float b) { return a - b; });
517
507
  }
518
508
 
519
509
  std::string tostring() const {
520
510
  char res[1000], *ptr = res;
521
- for(int i = 0; i < 8; i++) {
511
+ for (int i = 0; i < 8; i++) {
522
512
  ptr += sprintf(ptr, "%g,", f32[i]);
523
513
  }
524
514
  // strip last ,
525
515
  ptr[-1] = 0;
526
516
  return std::string(res);
527
517
  }
528
-
529
518
  };
530
519
 
531
-
532
520
  // hadd does not cross lanes
533
- inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
521
+ inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
534
522
  simd8float32 c;
535
523
  c.f32[0] = a.f32[0] + a.f32[1];
536
524
  c.f32[1] = a.f32[2] + a.f32[3];
@@ -545,7 +533,7 @@ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
545
533
  return c;
546
534
  }
547
535
 
548
- inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
536
+ inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
549
537
  simd8float32 c;
550
538
  c.f32[0] = a.f32[0];
551
539
  c.f32[1] = b.f32[0];
@@ -560,7 +548,7 @@ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
560
548
  return c;
561
549
  }
562
550
 
563
- inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
551
+ inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
564
552
  simd8float32 c;
565
553
  c.f32[0] = a.f32[2];
566
554
  c.f32[1] = b.f32[2];
@@ -576,14 +564,87 @@ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
576
564
  }
577
565
 
578
566
  // compute a * b + c
579
- inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
567
+ inline simd8float32 fmadd(
568
+ const simd8float32& a,
569
+ const simd8float32& b,
570
+ const simd8float32& c) {
580
571
  simd8float32 res;
581
- for(int i = 0; i < 8; i++) {
572
+ for (int i = 0; i < 8; i++) {
582
573
  res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
583
574
  }
584
575
  return res;
585
576
  }
586
577
 
578
+ namespace {
579
+
580
+ // get even float32's of a and b, interleaved
581
+ simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
582
+ simd8float32 c;
583
+
584
+ c.f32[0] = a.f32[0];
585
+ c.f32[1] = a.f32[2];
586
+ c.f32[2] = b.f32[0];
587
+ c.f32[3] = b.f32[2];
588
+
589
+ c.f32[4] = a.f32[4];
590
+ c.f32[5] = a.f32[6];
591
+ c.f32[6] = b.f32[4];
592
+ c.f32[7] = b.f32[6];
593
+
594
+ return c;
595
+ }
596
+
597
+ // get odd float32's of a and b, interleaved
598
+ simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
599
+ simd8float32 c;
600
+
601
+ c.f32[0] = a.f32[1];
602
+ c.f32[1] = a.f32[3];
603
+ c.f32[2] = b.f32[1];
604
+ c.f32[3] = b.f32[3];
605
+
606
+ c.f32[4] = a.f32[5];
607
+ c.f32[5] = a.f32[7];
608
+ c.f32[6] = b.f32[5];
609
+ c.f32[7] = b.f32[7];
610
+
611
+ return c;
612
+ }
613
+
614
+ // 3 cycles
615
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
616
+ simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
617
+ simd8float32 c;
618
+
619
+ c.f32[0] = a.f32[0];
620
+ c.f32[1] = a.f32[1];
621
+ c.f32[2] = a.f32[2];
622
+ c.f32[3] = a.f32[3];
623
+
624
+ c.f32[4] = b.f32[0];
625
+ c.f32[5] = b.f32[1];
626
+ c.f32[6] = b.f32[2];
627
+ c.f32[7] = b.f32[3];
628
+
629
+ return c;
630
+ }
631
+
632
+ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
633
+ simd8float32 c;
634
+
635
+ c.f32[0] = a.f32[4];
636
+ c.f32[1] = a.f32[5];
637
+ c.f32[2] = a.f32[6];
638
+ c.f32[3] = a.f32[7];
639
+
640
+ c.f32[4] = b.f32[4];
641
+ c.f32[5] = b.f32[5];
642
+ c.f32[6] = b.f32[6];
643
+ c.f32[7] = b.f32[7];
644
+
645
+ return c;
646
+ }
587
647
 
648
+ } // namespace
588
649
 
589
650
  } // namespace faiss