faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -7,17 +7,14 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <string>
10
+ #include <algorithm>
11
11
  #include <cstdint>
12
12
  #include <cstring>
13
- #include <functional>
14
- #include <algorithm>
13
+ #include <string>
15
14
 
16
15
  namespace faiss {
17
16
 
18
-
19
17
  struct simd256bit {
20
-
21
18
  union {
22
19
  uint8_t u8[32];
23
20
  uint16_t u16[16];
@@ -27,8 +24,7 @@ struct simd256bit {
27
24
 
28
25
  simd256bit() {}
29
26
 
30
- explicit simd256bit(const void *x)
31
- {
27
+ explicit simd256bit(const void* x) {
32
28
  memcpy(u8, x, 32);
33
29
  }
34
30
 
@@ -36,20 +32,20 @@ struct simd256bit {
36
32
  memset(u8, 0, 32);
37
33
  }
38
34
 
39
- void storeu(void *ptr) const {
35
+ void storeu(void* ptr) const {
40
36
  memcpy(ptr, u8, 32);
41
37
  }
42
38
 
43
- void loadu(const void *ptr) {
39
+ void loadu(const void* ptr) {
44
40
  memcpy(u8, ptr, 32);
45
41
  }
46
42
 
47
- void store(void *ptr) const {
43
+ void store(void* ptr) const {
48
44
  storeu(ptr);
49
45
  }
50
46
 
51
47
  void bin(char bits[257]) const {
52
- const char *bytes = (char*)this->u8;
48
+ const char* bytes = (char*)this->u8;
53
49
  for (int i = 0; i < 256; i++) {
54
50
  bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
55
51
  }
@@ -61,14 +57,10 @@ struct simd256bit {
61
57
  bin(bits);
62
58
  return std::string(bits);
63
59
  }
64
-
65
60
  };
66
61
 
67
-
68
-
69
-
70
62
  /// vector of 16 elements in uint16
71
- struct simd16uint16: simd256bit {
63
+ struct simd16uint16 : simd256bit {
72
64
  simd16uint16() {}
73
65
 
74
66
  explicit simd16uint16(int x) {
@@ -79,13 +71,13 @@ struct simd16uint16: simd256bit {
79
71
  set1(x);
80
72
  }
81
73
 
82
- explicit simd16uint16(simd256bit x): simd256bit(x) {}
74
+ explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}
83
75
 
84
- explicit simd16uint16(const uint16_t *x): simd256bit((const void*)x) {}
76
+ explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
85
77
 
86
- std::string elements_to_string(const char * fmt) const {
78
+ std::string elements_to_string(const char* fmt) const {
87
79
  char res[1000], *ptr = res;
88
- for(int i = 0; i < 16; i++) {
80
+ for (int i = 0; i < 16; i++) {
89
81
  ptr += sprintf(ptr, fmt, u16[i]);
90
82
  }
91
83
  // strip last ,
@@ -101,88 +93,86 @@ struct simd16uint16: simd256bit {
101
93
  return elements_to_string("%3d,");
102
94
  }
103
95
 
104
- static simd16uint16 unary_func(
105
- simd16uint16 a, std::function<uint16_t (uint16_t)> f)
106
- {
96
+ template <typename F>
97
+ static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
107
98
  simd16uint16 c;
108
- for(int j = 0; j < 16; j++) {
99
+ for (int j = 0; j < 16; j++) {
109
100
  c.u16[j] = f(a.u16[j]);
110
101
  }
111
102
  return c;
112
103
  }
113
104
 
114
-
105
+ template <typename F>
115
106
  static simd16uint16 binary_func(
116
- simd16uint16 a, simd16uint16 b,
117
- std::function<uint16_t (uint16_t, uint16_t)> f)
118
- {
107
+ const simd16uint16& a,
108
+ const simd16uint16& b,
109
+ F&& f) {
119
110
  simd16uint16 c;
120
- for(int j = 0; j < 16; j++) {
111
+ for (int j = 0; j < 16; j++) {
121
112
  c.u16[j] = f(a.u16[j], b.u16[j]);
122
113
  }
123
114
  return c;
124
115
  }
125
116
 
126
117
  void set1(uint16_t x) {
127
- for(int i = 0; i < 16; i++) {
118
+ for (int i = 0; i < 16; i++) {
128
119
  u16[i] = x;
129
120
  }
130
121
  }
131
122
 
132
123
  // shift must be known at compile time
133
- simd16uint16 operator >> (const int shift) const {
134
- return unary_func(*this, [shift](uint16_t a) {return a >> shift; });
124
+ simd16uint16 operator>>(const int shift) const {
125
+ return unary_func(*this, [shift](uint16_t a) { return a >> shift; });
135
126
  }
136
127
 
137
-
138
128
  // shift must be known at compile time
139
- simd16uint16 operator << (const int shift) const {
140
- return unary_func(*this, [shift](uint16_t a) {return a << shift; });
129
+ simd16uint16 operator<<(const int shift) const {
130
+ return unary_func(*this, [shift](uint16_t a) { return a << shift; });
141
131
  }
142
132
 
143
- simd16uint16 operator += (simd16uint16 other) {
133
+ simd16uint16 operator+=(const simd16uint16& other) {
144
134
  *this = *this + other;
145
135
  return *this;
146
136
  }
147
137
 
148
- simd16uint16 operator -= (simd16uint16 other) {
138
+ simd16uint16 operator-=(const simd16uint16& other) {
149
139
  *this = *this - other;
150
140
  return *this;
151
141
  }
152
142
 
153
- simd16uint16 operator + (simd16uint16 other) const {
154
- return binary_func(*this, other,
155
- [](uint16_t a, uint16_t b) {return a + b; }
156
- );
143
+ simd16uint16 operator+(const simd16uint16& other) const {
144
+ return binary_func(
145
+ *this, other, [](uint16_t a, uint16_t b) { return a + b; });
157
146
  }
158
147
 
159
- simd16uint16 operator - (simd16uint16 other) const {
160
- return binary_func(*this, other,
161
- [](uint16_t a, uint16_t b) {return a - b; }
162
- );
148
+ simd16uint16 operator-(const simd16uint16& other) const {
149
+ return binary_func(
150
+ *this, other, [](uint16_t a, uint16_t b) { return a - b; });
163
151
  }
164
152
 
165
- simd16uint16 operator & (simd256bit other) const {
166
- return binary_func(*this, simd16uint16(other),
167
- [](uint16_t a, uint16_t b) {return a & b; }
168
- );
153
+ simd16uint16 operator&(const simd256bit& other) const {
154
+ return binary_func(
155
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
156
+ return a & b;
157
+ });
169
158
  }
170
159
 
171
- simd16uint16 operator | (simd256bit other) const {
172
- return binary_func(*this, simd16uint16(other),
173
- [](uint16_t a, uint16_t b) {return a | b; }
174
- );
160
+ simd16uint16 operator|(const simd256bit& other) const {
161
+ return binary_func(
162
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
163
+ return a | b;
164
+ });
175
165
  }
176
166
 
177
167
  // returns binary masks
178
- simd16uint16 operator == (simd16uint16 other) const {
179
- return binary_func(*this, other,
180
- [](uint16_t a, uint16_t b) {return a == b ? 0xffff : 0; }
181
- );
168
+ simd16uint16 operator==(const simd16uint16& other) const {
169
+ return binary_func(*this, other, [](uint16_t a, uint16_t b) {
170
+ return a == b ? 0xffff : 0;
171
+ });
182
172
  }
183
173
 
184
- simd16uint16 operator ~() const {
185
- return unary_func(*this, [](uint16_t a) {return ~a; });
174
+ simd16uint16 operator~() const {
175
+ return unary_func(*this, [](uint16_t a) { return ~a; });
186
176
  }
187
177
 
188
178
  // get scalar at index 0
@@ -192,9 +182,9 @@ struct simd16uint16: simd256bit {
192
182
 
193
183
  // mask of elements where this >= thresh
194
184
  // 2 bit per component: 16 * 2 = 32 bit
195
- uint32_t ge_mask(simd16uint16 thresh) const {
185
+ uint32_t ge_mask(const simd16uint16& thresh) const {
196
186
  uint32_t gem = 0;
197
- for(int j = 0; j < 16; j++) {
187
+ for (int j = 0; j < 16; j++) {
198
188
  if (u16[j] >= thresh.u16[j]) {
199
189
  gem |= 3 << (j * 2);
200
190
  }
@@ -202,61 +192,57 @@ struct simd16uint16: simd256bit {
202
192
  return gem;
203
193
  }
204
194
 
205
- uint32_t le_mask(simd16uint16 thresh) const {
195
+ uint32_t le_mask(const simd16uint16& thresh) const {
206
196
  return thresh.ge_mask(*this);
207
197
  }
208
198
 
209
- uint32_t gt_mask(simd16uint16 thresh) const {
199
+ uint32_t gt_mask(const simd16uint16& thresh) const {
210
200
  return ~le_mask(thresh);
211
201
  }
212
202
 
213
- bool all_gt(simd16uint16 thresh) const {
203
+ bool all_gt(const simd16uint16& thresh) const {
214
204
  return le_mask(thresh) == 0;
215
205
  }
216
206
 
217
207
  // for debugging only
218
- uint16_t operator [] (int i) const {
208
+ uint16_t operator[](int i) const {
219
209
  return u16[i];
220
210
  }
221
211
 
222
- void accu_min(simd16uint16 incoming) {
223
- for(int j = 0; j < 16; j++) {
212
+ void accu_min(const simd16uint16& incoming) {
213
+ for (int j = 0; j < 16; j++) {
224
214
  if (incoming.u16[j] < u16[j]) {
225
215
  u16[j] = incoming.u16[j];
226
216
  }
227
217
  }
228
218
  }
229
219
 
230
- void accu_max(simd16uint16 incoming) {
231
- for(int j = 0; j < 16; j++) {
220
+ void accu_max(const simd16uint16& incoming) {
221
+ for (int j = 0; j < 16; j++) {
232
222
  if (incoming.u16[j] > u16[j]) {
233
223
  u16[j] = incoming.u16[j];
234
224
  }
235
225
  }
236
226
  }
237
-
238
227
  };
239
228
 
240
-
241
229
  // not really a std::min because it returns an elementwise min
242
- inline simd16uint16 min(simd16uint16 av, simd16uint16 bv) {
243
- return simd16uint16::binary_func(av, bv,
244
- [](uint16_t a, uint16_t b) {return std::min(a, b); }
245
- );
230
+ inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
231
+ return simd16uint16::binary_func(
232
+ av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
246
233
  }
247
234
 
248
- inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
249
- return simd16uint16::binary_func(av, bv,
250
- [](uint16_t a, uint16_t b) {return std::max(a, b); }
251
- );
235
+ inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
236
+ return simd16uint16::binary_func(
237
+ av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
252
238
  }
253
239
 
254
240
  // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
255
241
  // return (a0 + a1, b0 + b1)
256
242
  // TODO find a better name
257
- inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
243
+ inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
258
244
  simd16uint16 c;
259
- for(int j = 0; j < 8; j++) {
245
+ for (int j = 0; j < 8; j++) {
260
246
  c.u16[j] = a.u16[j] + a.u16[j + 8];
261
247
  c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
262
248
  }
@@ -265,9 +251,12 @@ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
265
251
 
266
252
  // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
267
253
  // of d0 and d1 with thr
268
- inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
254
+ inline uint32_t cmp_ge32(
255
+ const simd16uint16& d0,
256
+ const simd16uint16& d1,
257
+ const simd16uint16& thr) {
269
258
  uint32_t gem = 0;
270
- for(int j = 0; j < 16; j++) {
259
+ for (int j = 0; j < 16; j++) {
271
260
  if (d0.u16[j] >= thr.u16[j]) {
272
261
  gem |= 1 << j;
273
262
  }
@@ -278,10 +267,12 @@ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
278
267
  return gem;
279
268
  }
280
269
 
281
-
282
- inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
270
+ inline uint32_t cmp_le32(
271
+ const simd16uint16& d0,
272
+ const simd16uint16& d1,
273
+ const simd16uint16& thr) {
283
274
  uint32_t gem = 0;
284
- for(int j = 0; j < 16; j++) {
275
+ for (int j = 0; j < 16; j++) {
285
276
  if (d0.u16[j] <= thr.u16[j]) {
286
277
  gem |= 1 << j;
287
278
  }
@@ -292,24 +283,25 @@ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
292
283
  return gem;
293
284
  }
294
285
 
295
-
296
-
297
286
  // vector of 32 unsigned 8-bit integers
298
- struct simd32uint8: simd256bit {
299
-
287
+ struct simd32uint8 : simd256bit {
300
288
  simd32uint8() {}
301
289
 
302
- explicit simd32uint8(int x) {set1(x); }
290
+ explicit simd32uint8(int x) {
291
+ set1(x);
292
+ }
303
293
 
304
- explicit simd32uint8(uint8_t x) {set1(x); }
294
+ explicit simd32uint8(uint8_t x) {
295
+ set1(x);
296
+ }
305
297
 
306
- explicit simd32uint8(simd256bit x): simd256bit(x) {}
298
+ explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
307
299
 
308
- explicit simd32uint8(const uint8_t *x): simd256bit((const void*)x) {}
300
+ explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
309
301
 
310
- std::string elements_to_string(const char * fmt) const {
302
+ std::string elements_to_string(const char* fmt) const {
311
303
  char res[1000], *ptr = res;
312
- for(int i = 0; i < 32; i++) {
304
+ for (int i = 0; i < 32; i++) {
313
305
  ptr += sprintf(ptr, fmt, u8[i]);
314
306
  }
315
307
  // strip last ,
@@ -326,39 +318,38 @@ struct simd32uint8: simd256bit {
326
318
  }
327
319
 
328
320
  void set1(uint8_t x) {
329
- for(int j = 0; j < 32; j++) {
321
+ for (int j = 0; j < 32; j++) {
330
322
  u8[j] = x;
331
323
  }
332
324
  }
333
325
 
326
+ template <typename F>
334
327
  static simd32uint8 binary_func(
335
- simd32uint8 a, simd32uint8 b,
336
- std::function<uint8_t (uint8_t, uint8_t)> f)
337
- {
328
+ const simd32uint8& a,
329
+ const simd32uint8& b,
330
+ F&& f) {
338
331
  simd32uint8 c;
339
- for(int j = 0; j < 32; j++) {
332
+ for (int j = 0; j < 32; j++) {
340
333
  c.u8[j] = f(a.u8[j], b.u8[j]);
341
334
  }
342
335
  return c;
343
336
  }
344
337
 
345
-
346
- simd32uint8 operator & (simd256bit other) const {
347
- return binary_func(*this, simd32uint8(other),
348
- [](uint8_t a, uint8_t b) {return a & b; }
349
- );
338
+ simd32uint8 operator&(const simd256bit& other) const {
339
+ return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
340
+ return a & b;
341
+ });
350
342
  }
351
343
 
352
- simd32uint8 operator + (simd32uint8 other) const {
353
- return binary_func(*this, other,
354
- [](uint8_t a, uint8_t b) {return a + b; }
355
- );
344
+ simd32uint8 operator+(const simd32uint8& other) const {
345
+ return binary_func(
346
+ *this, other, [](uint8_t a, uint8_t b) { return a + b; });
356
347
  }
357
348
 
358
349
  // The very important operation that everything relies on
359
- simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
350
+ simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
360
351
  simd32uint8 c;
361
- for(int j = 0; j < 32; j++) {
352
+ for (int j = 0; j < 32; j++) {
362
353
  if (idx.u8[j] & 0x80) {
363
354
  c.u8[j] = 0;
364
355
  } else {
@@ -376,31 +367,29 @@ struct simd32uint8: simd256bit {
376
367
  // extract + 0-extend lane
377
368
  // this operation is slow (3 cycles)
378
369
 
379
- simd32uint8 operator += (simd32uint8 other) {
370
+ simd32uint8 operator+=(const simd32uint8& other) {
380
371
  *this = *this + other;
381
372
  return *this;
382
373
  }
383
374
 
384
375
  // for debugging only
385
- uint8_t operator [] (int i) const {
376
+ uint8_t operator[](int i) const {
386
377
  return u8[i];
387
378
  }
388
-
389
379
  };
390
380
 
391
-
392
381
  // convert with saturation
393
382
  // careful: this does not cross lanes, so the order is weird
394
- inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
383
+ inline simd32uint8 uint16_to_uint8_saturate(
384
+ const simd16uint16& a,
385
+ const simd16uint16& b) {
395
386
  simd32uint8 c;
396
387
 
397
- auto saturate_16_to_8 = [] (uint16_t x) {
398
- return x >= 256 ? 0xff : x;
399
- };
388
+ auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };
400
389
 
401
390
  for (int i = 0; i < 8; i++) {
402
- c.u8[ i] = saturate_16_to_8(a.u16[i]);
403
- c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
391
+ c.u8[i] = saturate_16_to_8(a.u16[i]);
392
+ c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
404
393
  c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
405
394
  c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
406
395
  }
@@ -408,7 +397,7 @@ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
408
397
  }
409
398
 
410
399
  /// get most significant bit of each byte
411
- inline uint32_t get_MSBs(simd32uint8 a) {
400
+ inline uint32_t get_MSBs(const simd32uint8& a) {
412
401
  uint32_t res = 0;
413
402
  for (int i = 0; i < 32; i++) {
414
403
  if (a.u8[i] & 0x80) {
@@ -419,7 +408,10 @@ inline uint32_t get_MSBs(simd32uint8 a) {
419
408
  }
420
409
 
421
410
  /// use MSB of each byte of mask to select a byte between a and b
422
- inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
411
+ inline simd32uint8 blendv(
412
+ const simd32uint8& a,
413
+ const simd32uint8& b,
414
+ const simd32uint8& mask) {
423
415
  simd32uint8 c;
424
416
  for (int i = 0; i < 32; i++) {
425
417
  if (mask.u8[i] & 0x80) {
@@ -431,23 +423,21 @@ inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
431
423
  return c;
432
424
  }
433
425
 
434
-
435
-
436
-
437
426
  /// vector of 8 unsigned 32-bit integers
438
- struct simd8uint32: simd256bit {
427
+ struct simd8uint32 : simd256bit {
439
428
  simd8uint32() {}
440
429
 
430
+ explicit simd8uint32(uint32_t x) {
431
+ set1(x);
432
+ }
441
433
 
442
- explicit simd8uint32(uint32_t x) {set1(x); }
443
-
444
- explicit simd8uint32(simd256bit x): simd256bit(x) {}
434
+ explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}
445
435
 
446
- explicit simd8uint32(const uint8_t *x): simd256bit((const void*)x) {}
436
+ explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
447
437
 
448
- std::string elements_to_string(const char * fmt) const {
438
+ std::string elements_to_string(const char* fmt) const {
449
439
  char res[1000], *ptr = res;
450
- for(int i = 0; i < 8; i++) {
440
+ for (int i = 0; i < 8; i++) {
451
441
  ptr += sprintf(ptr, fmt, u32[i]);
452
442
  }
453
443
  // strip last ,
@@ -468,69 +458,67 @@ struct simd8uint32: simd256bit {
468
458
  u32[i] = x;
469
459
  }
470
460
  }
471
-
472
461
  };
473
462
 
474
- struct simd8float32: simd256bit {
475
-
463
+ struct simd8float32 : simd256bit {
476
464
  simd8float32() {}
477
465
 
478
- explicit simd8float32(simd256bit x): simd256bit(x) {}
466
+ explicit simd8float32(const simd256bit& x) : simd256bit(x) {}
479
467
 
480
- explicit simd8float32(float x) {set1(x); }
468
+ explicit simd8float32(float x) {
469
+ set1(x);
470
+ }
481
471
 
482
- explicit simd8float32(const float *x) {loadu((void*)x); }
472
+ explicit simd8float32(const float* x) {
473
+ loadu((void*)x);
474
+ }
483
475
 
484
476
  void set1(float x) {
485
- for(int i = 0; i < 8; i++) {
477
+ for (int i = 0; i < 8; i++) {
486
478
  f32[i] = x;
487
479
  }
488
480
  }
489
481
 
482
+ template <typename F>
490
483
  static simd8float32 binary_func(
491
- simd8float32 a, simd8float32 b,
492
- std::function<float (float, float)> f)
493
- {
484
+ const simd8float32& a,
485
+ const simd8float32& b,
486
+ F&& f) {
494
487
  simd8float32 c;
495
- for(int j = 0; j < 8; j++) {
488
+ for (int j = 0; j < 8; j++) {
496
489
  c.f32[j] = f(a.f32[j], b.f32[j]);
497
490
  }
498
491
  return c;
499
492
  }
500
493
 
501
- simd8float32 operator * (simd8float32 other) const {
502
- return binary_func(*this, other,
503
- [](float a, float b) {return a * b; }
504
- );
494
+ simd8float32 operator*(const simd8float32& other) const {
495
+ return binary_func(
496
+ *this, other, [](float a, float b) { return a * b; });
505
497
  }
506
498
 
507
- simd8float32 operator + (simd8float32 other) const {
508
- return binary_func(*this, other,
509
- [](float a, float b) {return a + b; }
510
- );
499
+ simd8float32 operator+(const simd8float32& other) const {
500
+ return binary_func(
501
+ *this, other, [](float a, float b) { return a + b; });
511
502
  }
512
503
 
513
- simd8float32 operator - (simd8float32 other) const {
514
- return binary_func(*this, other,
515
- [](float a, float b) {return a - b; }
516
- );
504
+ simd8float32 operator-(const simd8float32& other) const {
505
+ return binary_func(
506
+ *this, other, [](float a, float b) { return a - b; });
517
507
  }
518
508
 
519
509
  std::string tostring() const {
520
510
  char res[1000], *ptr = res;
521
- for(int i = 0; i < 8; i++) {
511
+ for (int i = 0; i < 8; i++) {
522
512
  ptr += sprintf(ptr, "%g,", f32[i]);
523
513
  }
524
514
  // strip last ,
525
515
  ptr[-1] = 0;
526
516
  return std::string(res);
527
517
  }
528
-
529
518
  };
530
519
 
531
-
532
520
  // hadd does not cross lanes
533
- inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
521
+ inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
534
522
  simd8float32 c;
535
523
  c.f32[0] = a.f32[0] + a.f32[1];
536
524
  c.f32[1] = a.f32[2] + a.f32[3];
@@ -545,7 +533,7 @@ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
545
533
  return c;
546
534
  }
547
535
 
548
- inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
536
+ inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
549
537
  simd8float32 c;
550
538
  c.f32[0] = a.f32[0];
551
539
  c.f32[1] = b.f32[0];
@@ -560,7 +548,7 @@ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
560
548
  return c;
561
549
  }
562
550
 
563
- inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
551
+ inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
564
552
  simd8float32 c;
565
553
  c.f32[0] = a.f32[2];
566
554
  c.f32[1] = b.f32[2];
@@ -576,14 +564,87 @@ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
576
564
  }
577
565
 
578
566
  // compute a * b + c
579
- inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
567
+ inline simd8float32 fmadd(
568
+ const simd8float32& a,
569
+ const simd8float32& b,
570
+ const simd8float32& c) {
580
571
  simd8float32 res;
581
- for(int i = 0; i < 8; i++) {
572
+ for (int i = 0; i < 8; i++) {
582
573
  res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
583
574
  }
584
575
  return res;
585
576
  }
586
577
 
578
+ namespace {
579
+
580
+ // get even float32's of a and b, interleaved
581
+ simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
582
+ simd8float32 c;
583
+
584
+ c.f32[0] = a.f32[0];
585
+ c.f32[1] = a.f32[2];
586
+ c.f32[2] = b.f32[0];
587
+ c.f32[3] = b.f32[2];
588
+
589
+ c.f32[4] = a.f32[4];
590
+ c.f32[5] = a.f32[6];
591
+ c.f32[6] = b.f32[4];
592
+ c.f32[7] = b.f32[6];
593
+
594
+ return c;
595
+ }
596
+
597
+ // get odd float32's of a and b, interleaved
598
+ simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
599
+ simd8float32 c;
600
+
601
+ c.f32[0] = a.f32[1];
602
+ c.f32[1] = a.f32[3];
603
+ c.f32[2] = b.f32[1];
604
+ c.f32[3] = b.f32[3];
605
+
606
+ c.f32[4] = a.f32[5];
607
+ c.f32[5] = a.f32[7];
608
+ c.f32[6] = b.f32[5];
609
+ c.f32[7] = b.f32[7];
610
+
611
+ return c;
612
+ }
613
+
614
+ // 3 cycles
615
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
616
+ simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
617
+ simd8float32 c;
618
+
619
+ c.f32[0] = a.f32[0];
620
+ c.f32[1] = a.f32[1];
621
+ c.f32[2] = a.f32[2];
622
+ c.f32[3] = a.f32[3];
623
+
624
+ c.f32[4] = b.f32[0];
625
+ c.f32[5] = b.f32[1];
626
+ c.f32[6] = b.f32[2];
627
+ c.f32[7] = b.f32[3];
628
+
629
+ return c;
630
+ }
631
+
632
+ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
633
+ simd8float32 c;
634
+
635
+ c.f32[0] = a.f32[4];
636
+ c.f32[1] = a.f32[5];
637
+ c.f32[2] = a.f32[6];
638
+ c.f32[3] = a.f32[7];
639
+
640
+ c.f32[4] = b.f32[4];
641
+ c.f32[5] = b.f32[5];
642
+ c.f32[6] = b.f32[6];
643
+ c.f32[7] = b.f32[7];
644
+
645
+ return c;
646
+ }
587
647
 
648
+ } // namespace
588
649
 
589
650
  } // namespace faiss