faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,9 +9,9 @@
9
9
  #ifndef FAISS_LATTICE_ZN_H
10
10
  #define FAISS_LATTICE_ZN_H
11
11
 
12
- #include <vector>
13
12
  #include <stddef.h>
14
13
  #include <stdint.h>
14
+ #include <vector>
15
15
 
16
16
  namespace faiss {
17
17
 
@@ -32,23 +32,20 @@ struct ZnSphereSearch {
32
32
  ZnSphereSearch(int dim, int r2);
33
33
 
34
34
  /// find nearest centroid. x does not need to be normalized
35
- float search(const float *x, float *c) const;
35
+ float search(const float* x, float* c) const;
36
36
 
37
37
  /// full call. Requires externally-allocated temp space
38
- float search(const float *x, float *c,
39
- float *tmp, // size 2 *dim
40
- int *tmp_int, // size dim
41
- int *ibest_out = nullptr
42
- ) const;
38
+ float search(
39
+ const float* x,
40
+ float* c,
41
+ float* tmp, // size 2 *dim
42
+ int* tmp_int, // size dim
43
+ int* ibest_out = nullptr) const;
43
44
 
44
45
  // multi-threaded
45
- void search_multi(int n, const float *x,
46
- float *c_out,
47
- float *dp_out);
48
-
46
+ void search_multi(int n, const float* x, float* c_out, float* dp_out);
49
47
  };
50
48
 
51
-
52
49
  /***************************************************************************
53
50
  * Support ids as well.
54
51
  *
@@ -60,30 +57,31 @@ struct EnumeratedVectors {
60
57
  uint64_t nv;
61
58
  int dim;
62
59
 
63
- explicit EnumeratedVectors(int dim): nv(0), dim(dim) {}
60
+ explicit EnumeratedVectors(int dim) : nv(0), dim(dim) {}
64
61
 
65
62
  /// encode a vector from a collection
66
- virtual uint64_t encode(const float *x) const = 0;
63
+ virtual uint64_t encode(const float* x) const = 0;
67
64
 
68
65
  /// decode it
69
- virtual void decode(uint64_t code, float *c) const = 0;
66
+ virtual void decode(uint64_t code, float* c) const = 0;
70
67
 
71
68
  // call encode on nc vectors
72
- void encode_multi (size_t nc, const float *c,
73
- uint64_t * codes) const;
69
+ void encode_multi(size_t nc, const float* c, uint64_t* codes) const;
74
70
 
75
71
  // call decode on nc codes
76
- void decode_multi (size_t nc, const uint64_t * codes,
77
- float *c) const;
72
+ void decode_multi(size_t nc, const uint64_t* codes, float* c) const;
78
73
 
79
74
  // find the nearest neighbor of each xq
80
75
  // (decodes and computes distances)
81
- void find_nn (size_t n, const uint64_t * codes,
82
- size_t nq, const float *xq,
83
- int64_t *idx, float *dis);
76
+ void find_nn(
77
+ size_t n,
78
+ const uint64_t* codes,
79
+ size_t nq,
80
+ const float* xq,
81
+ int64_t* idx,
82
+ float* dis);
84
83
 
85
84
  virtual ~EnumeratedVectors() {}
86
-
87
85
  };
88
86
 
89
87
  struct Repeat {
@@ -100,26 +98,24 @@ struct Repeats {
100
98
  std::vector<Repeat> repeats;
101
99
 
102
100
  // initialize from a template of the atom.
103
- Repeats(int dim = 0, const float *c = nullptr);
101
+ Repeats(int dim = 0, const float* c = nullptr);
104
102
 
105
103
  // count number of possible codes for this atom
106
104
  uint64_t count() const;
107
105
 
108
- uint64_t encode(const float *c) const;
106
+ uint64_t encode(const float* c) const;
109
107
 
110
- void decode(uint64_t code, float *c) const;
108
+ void decode(uint64_t code, float* c) const;
111
109
  };
112
110
 
113
-
114
111
  /** codec that can return ids for the encoded vectors
115
112
  *
116
113
  * uses the ZnSphereSearch to encode the vector by encoding the
117
114
  * permutation and signs. Depends on ZnSphereSearch because it uses
118
115
  * the atom numbers */
119
- struct ZnSphereCodec: ZnSphereSearch, EnumeratedVectors {
120
-
121
- struct CodeSegment:Repeats {
122
- explicit CodeSegment(const Repeats & r): Repeats(r) {}
116
+ struct ZnSphereCodec : ZnSphereSearch, EnumeratedVectors {
117
+ struct CodeSegment : Repeats {
118
+ explicit CodeSegment(const Repeats& r) : Repeats(r) {}
123
119
  uint64_t c0; // first code assigned to segment
124
120
  int signbits;
125
121
  };
@@ -130,13 +126,12 @@ struct ZnSphereCodec: ZnSphereSearch, EnumeratedVectors {
130
126
 
131
127
  ZnSphereCodec(int dim, int r2);
132
128
 
133
- uint64_t search_and_encode(const float *x) const;
129
+ uint64_t search_and_encode(const float* x) const;
134
130
 
135
- void decode(uint64_t code, float *c) const override;
131
+ void decode(uint64_t code, float* c) const override;
136
132
 
137
133
  /// takes vectors that do not need to be centroids
138
- uint64_t encode(const float *x) const override;
139
-
134
+ uint64_t encode(const float* x) const override;
140
135
  };
141
136
 
142
137
  /** recursive sphere codec
@@ -145,8 +140,7 @@ struct ZnSphereCodec: ZnSphereSearch, EnumeratedVectors {
145
140
  * centroids found by the ZnSphereSearch. The codes are *not*
146
141
  * compatible with the ones of ZnSpehreCodec
147
142
  */
148
- struct ZnSphereCodecRec: EnumeratedVectors {
149
-
143
+ struct ZnSphereCodecRec : EnumeratedVectors {
150
144
  int r2;
151
145
 
152
146
  int log2_dim;
@@ -154,19 +148,19 @@ struct ZnSphereCodecRec: EnumeratedVectors {
154
148
 
155
149
  ZnSphereCodecRec(int dim, int r2);
156
150
 
157
- uint64_t encode_centroid(const float *c) const;
151
+ uint64_t encode_centroid(const float* c) const;
158
152
 
159
- void decode(uint64_t code, float *c) const override;
153
+ void decode(uint64_t code, float* c) const override;
160
154
 
161
155
  /// vectors need to be centroids (does not work on arbitrary
162
156
  /// vectors)
163
- uint64_t encode(const float *x) const override;
157
+ uint64_t encode(const float* x) const override;
164
158
 
165
159
  std::vector<uint64_t> all_nv;
166
160
  std::vector<uint64_t> all_nv_cum;
167
161
 
168
162
  int decode_cache_ld;
169
- std::vector<std::vector<float> > decode_cache;
163
+ std::vector<std::vector<float>> decode_cache;
170
164
 
171
165
  // nb of vectors in the sphere in dim 2^ld with r2 radius
172
166
  uint64_t get_nv(int ld, int r2a) const;
@@ -174,26 +168,21 @@ struct ZnSphereCodecRec: EnumeratedVectors {
174
168
  // cumulative version
175
169
  uint64_t get_nv_cum(int ld, int r2t, int r2a) const;
176
170
  void set_nv_cum(int ld, int r2t, int r2a, uint64_t v);
177
-
178
171
  };
179
172
 
180
-
181
173
  /** Codec that uses the recursive codec if dim is a power of 2 and
182
174
  * the regular one otherwise */
183
- struct ZnSphereCodecAlt: ZnSphereCodec {
175
+ struct ZnSphereCodecAlt : ZnSphereCodec {
184
176
  bool use_rec;
185
177
  ZnSphereCodecRec znc_rec;
186
178
 
187
- ZnSphereCodecAlt (int dim, int r2);
188
-
189
- uint64_t encode(const float *x) const override;
190
-
191
- void decode(uint64_t code, float *c) const override;
192
-
193
- };
179
+ ZnSphereCodecAlt(int dim, int r2);
194
180
 
181
+ uint64_t encode(const float* x) const override;
195
182
 
183
+ void decode(uint64_t code, float* c) const override;
196
184
  };
197
185
 
186
+ } // namespace faiss
198
187
 
199
188
  #endif
@@ -7,14 +7,12 @@
7
7
 
8
8
  #pragma once
9
9
 
10
-
11
10
  #ifdef _MSC_VER
12
11
 
13
12
  /*******************************************************
14
13
  * Windows specific macros
15
14
  *******************************************************/
16
15
 
17
-
18
16
  #ifdef FAISS_MAIN_LIB
19
17
  #define FAISS_API __declspec(dllexport)
20
18
  #else // _FAISS_MAIN_LIB
@@ -23,7 +21,8 @@
23
21
 
24
22
  #define __PRETTY_FUNCTION__ __FUNCSIG__
25
23
 
26
- #define posix_memalign(p, a, s) (((*(p)) = _aligned_malloc((s), (a))), *(p) ?0 :errno)
24
+ #define posix_memalign(p, a, s) \
25
+ (((*(p)) = _aligned_malloc((s), (a))), *(p) ? 0 : errno)
27
26
  #define posix_memalign_free _aligned_free
28
27
 
29
28
  // aligned should be in front of the declaration
@@ -39,18 +38,42 @@ inline int __builtin_ctzll(uint64_t x) {
39
38
  return (int)ret;
40
39
  }
41
40
 
41
+ // cudatoolkit provides __builtin_ctz for NVCC >= 11.0
42
+ #if !defined(__CUDACC__) || __CUDACC_VER_MAJOR__ < 11
42
43
  inline int __builtin_ctz(unsigned long x) {
43
44
  unsigned long ret;
44
45
  _BitScanForward(&ret, x);
45
46
  return (int)ret;
46
47
  }
48
+ #endif
47
49
 
48
50
  inline int __builtin_clzll(uint64_t x) {
49
51
  return (int)__lzcnt64(x);
50
52
  }
51
53
 
54
+ #define __builtin_popcount __popcnt
52
55
  #define __builtin_popcountl __popcnt64
53
56
 
57
+ // MSVC does not define __SSEx__, and _M_IX86_FP is only defined on 32-bit
58
+ // processors cf.
59
+ // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros
60
+ #ifdef __AVX__
61
+ #define __SSE__ 1
62
+ #define __SSE2__ 1
63
+ #define __SSE3__ 1
64
+ #define __SSE4_1__ 1
65
+ #define __SSE4_2__ 1
66
+ #endif
67
+
68
+ // MSVC sets FMA and F16C automatically when using AVX2
69
+ // Ref. FMA (under /arch:AVX2):
70
+ // https://docs.microsoft.com/en-us/cpp/build/reference/arch-x64 Ref. F16C (2nd
71
+ // paragraph): https://walbourn.github.io/directxmath-avx2/
72
+ #ifdef __AVX2__
73
+ #define __FMA__ 1
74
+ #define __F16C__ 1
75
+ #endif
76
+
54
77
  #else
55
78
  /*******************************************************
56
79
  * Linux and OSX
@@ -59,10 +82,8 @@ inline int __builtin_clzll(uint64_t x) {
59
82
  #define FAISS_API
60
83
  #define posix_memalign_free free
61
84
 
62
- // aligned should be *in front* of the declaration, for compatibility with windows
63
- #define ALIGNED(x) __attribute__ ((aligned(x)))
85
+ // aligned should be *in front* of the declaration, for compatibility with
86
+ // windows
87
+ #define ALIGNED(x) __attribute__((aligned(x)))
64
88
 
65
89
  #endif // _MSC_VER
66
-
67
-
68
-
@@ -5,37 +5,33 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- #include <faiss/impl/pq4_fast_scan.h>
9
8
  #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/impl/pq4_fast_scan.h>
10
10
  #include <faiss/impl/simd_result_handlers.h>
11
11
 
12
12
  #include <array>
13
13
 
14
-
15
14
  namespace faiss {
16
15
 
17
-
18
16
  using namespace simd_result_handlers;
19
17
 
20
-
21
-
22
18
  /***************************************************************
23
19
  * Packing functions for codes
24
20
  ***************************************************************/
25
21
 
26
-
27
-
28
22
  namespace {
29
23
 
30
24
  /* extract the column starting at (i, j)
31
25
  * from packed matrix src of size (m, n)*/
32
- template<typename T, class TA>
26
+ template <typename T, class TA>
33
27
  void get_matrix_column(
34
- T * src,
35
- size_t m, size_t n,
36
- int64_t i, int64_t j,
37
- TA & dest) {
38
- for(int64_t k = 0; k < dest.size(); k++) {
28
+ T* src,
29
+ size_t m,
30
+ size_t n,
31
+ int64_t i,
32
+ int64_t j,
33
+ TA& dest) {
34
+ for (int64_t k = 0; k < dest.size(); k++) {
39
35
  if (k + i >= 0 && k + i < m) {
40
36
  dest[k] = src[(k + i) * n + j];
41
37
  } else {
@@ -46,38 +42,34 @@ void get_matrix_column(
46
42
 
47
43
  } // anonymous namespace
48
44
 
49
-
50
45
  void pq4_pack_codes(
51
- const uint8_t *codes,
52
- size_t ntotal, size_t M,
53
- size_t nb, size_t bbs, size_t nsq,
54
- uint8_t *blocks
55
- )
56
- {
46
+ const uint8_t* codes,
47
+ size_t ntotal,
48
+ size_t M,
49
+ size_t nb,
50
+ size_t bbs,
51
+ size_t nsq,
52
+ uint8_t* blocks) {
57
53
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
58
54
  FAISS_THROW_IF_NOT(nb % bbs == 0);
59
55
  FAISS_THROW_IF_NOT(nsq % 2 == 0);
60
56
 
61
57
  memset(blocks, 0, nb * nsq / 2);
62
- const uint8_t perm0[16] =
63
- {0, 8, 1, 9, 2, 10, 3, 11,
64
- 4, 12, 5, 13, 6, 14, 7, 15};
58
+ const uint8_t perm0[16] = {
59
+ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
65
60
 
66
- uint8_t *codes2 = blocks;
67
- for(size_t i0 = 0; i0 < nb; i0 += bbs) {
68
- for(int sq = 0; sq < nsq; sq += 2) {
69
- for(size_t i = 0; i < bbs; i += 32) {
61
+ uint8_t* codes2 = blocks;
62
+ for (size_t i0 = 0; i0 < nb; i0 += bbs) {
63
+ for (int sq = 0; sq < nsq; sq += 2) {
64
+ for (size_t i = 0; i < bbs; i += 32) {
70
65
  std::array<uint8_t, 32> c, c0, c1;
71
66
  get_matrix_column(
72
- codes, ntotal,
73
- (M + 1) / 2,
74
- i0 + i, sq / 2, c
75
- );
76
- for(int j = 0; j < 32; j++) {
67
+ codes, ntotal, (M + 1) / 2, i0 + i, sq / 2, c);
68
+ for (int j = 0; j < 32; j++) {
77
69
  c0[j] = c[j] & 15;
78
70
  c1[j] = c[j] >> 4;
79
71
  }
80
- for(int j = 0; j < 16; j++) {
72
+ for (int j = 0; j < 16; j++) {
81
73
  uint8_t d0, d1;
82
74
  d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4);
83
75
  d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4);
@@ -91,36 +83,33 @@ void pq4_pack_codes(
91
83
  }
92
84
 
93
85
  void pq4_pack_codes_range(
94
- const uint8_t *codes,
86
+ const uint8_t* codes,
95
87
  size_t M,
96
- size_t i0, size_t i1,
97
- size_t bbs, size_t M2,
98
- uint8_t * blocks
99
- ) {
100
- const uint8_t perm0[16] =
101
- {0, 8, 1, 9, 2, 10, 3, 11,
102
- 4, 12, 5, 13, 6, 14, 7, 15};
88
+ size_t i0,
89
+ size_t i1,
90
+ size_t bbs,
91
+ size_t M2,
92
+ uint8_t* blocks) {
93
+ const uint8_t perm0[16] = {
94
+ 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
103
95
 
104
96
  // range of affected blocks
105
97
  size_t block0 = i0 / bbs;
106
98
  size_t block1 = ((i1 - 1) / bbs) + 1;
107
99
 
108
100
  for (size_t b = block0; b < block1; b++) {
109
- uint8_t *codes2 = blocks + b * bbs * M2 / 2;
101
+ uint8_t* codes2 = blocks + b * bbs * M2 / 2;
110
102
  int64_t i_base = b * bbs - i0;
111
- for(int sq = 0; sq < M2; sq += 2) {
112
- for(size_t i = 0; i < bbs; i += 32) {
103
+ for (int sq = 0; sq < M2; sq += 2) {
104
+ for (size_t i = 0; i < bbs; i += 32) {
113
105
  std::array<uint8_t, 32> c, c0, c1;
114
106
  get_matrix_column(
115
- codes, i1 - i0,
116
- (M + 1) / 2,
117
- i_base + i, sq / 2, c
118
- );
119
- for(int j = 0; j < 32; j++) {
107
+ codes, i1 - i0, (M + 1) / 2, i_base + i, sq / 2, c);
108
+ for (int j = 0; j < 32; j++) {
120
109
  c0[j] = c[j] & 15;
121
110
  c1[j] = c[j] >> 4;
122
111
  }
123
- for(int j = 0; j < 16; j++) {
112
+ for (int j = 0; j < 16; j++) {
124
113
  uint8_t d0, d1;
125
114
  d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4);
126
115
  d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4);
@@ -131,14 +120,14 @@ void pq4_pack_codes_range(
131
120
  }
132
121
  }
133
122
  }
134
-
135
123
  }
136
124
 
137
-
138
125
  uint8_t pq4_get_packed_element(
139
- const uint8_t *data, size_t bbs, size_t nsq,
140
- size_t i, size_t sq
141
- ) {
126
+ const uint8_t* data,
127
+ size_t bbs,
128
+ size_t nsq,
129
+ size_t i,
130
+ size_t sq) {
142
131
  // move to correct bbs-sized block
143
132
  data += (i / bbs * (nsq / 2) + sq / 2) * bbs;
144
133
  sq = sq & 1;
@@ -151,122 +140,86 @@ uint8_t pq4_get_packed_element(
151
140
  if (sq == 1) {
152
141
  data += 16;
153
142
  }
154
- const uint8_t iperm0[16] =
155
- {0, 2, 4, 6, 8, 10, 12, 14,
156
- 1, 3, 5, 7, 9, 11, 13, 15};
143
+ const uint8_t iperm0[16] = {
144
+ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
157
145
  if (i < 16) {
158
146
  return data[iperm0[i]] & 15;
159
147
  } else {
160
148
  return data[iperm0[i - 16]] >> 4;
161
149
  }
162
-
163
150
  }
164
151
 
165
152
  /***************************************************************
166
153
  * Packing functions for Look-Up Tables (LUT)
167
154
  ***************************************************************/
168
155
 
169
-
170
-
171
-
172
- void pq4_pack_LUT(
173
- int nq, int nsq,
174
- const uint8_t *src,
175
- uint8_t *dest)
176
- {
177
-
178
- for(int q = 0; q < nq; q++) {
179
- for(int sq = 0; sq < nsq; sq += 2) {
180
- memcpy(
181
- dest + (sq / 2 * nq + q) * 32,
182
- src + (q * nsq + sq) * 16,
183
- 16
184
- );
185
- memcpy(
186
- dest + (sq / 2 * nq + q) * 32 + 16,
187
- src + (q * nsq + sq + 1) * 16,
188
- 16
189
- );
156
+ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest) {
157
+ for (int q = 0; q < nq; q++) {
158
+ for (int sq = 0; sq < nsq; sq += 2) {
159
+ memcpy(dest + (sq / 2 * nq + q) * 32,
160
+ src + (q * nsq + sq) * 16,
161
+ 16);
162
+ memcpy(dest + (sq / 2 * nq + q) * 32 + 16,
163
+ src + (q * nsq + sq + 1) * 16,
164
+ 16);
190
165
  }
191
166
  }
192
167
  }
193
168
 
194
-
195
- int pq4_pack_LUT_qbs(
196
- int qbs, int nsq,
197
- const uint8_t *src,
198
- uint8_t *dest)
199
- {
169
+ int pq4_pack_LUT_qbs(int qbs, int nsq, const uint8_t* src, uint8_t* dest) {
200
170
  FAISS_THROW_IF_NOT(nsq % 2 == 0);
201
171
  size_t dim12 = 16 * nsq;
202
172
  int i0 = 0;
203
173
  int qi = qbs;
204
- while(qi) {
174
+ while (qi) {
205
175
  int nq = qi & 15;
206
176
  qi >>= 4;
207
- pq4_pack_LUT(
208
- nq, nsq,
209
- src + i0 * dim12,
210
- dest + i0 * dim12
211
- );
177
+ pq4_pack_LUT(nq, nsq, src + i0 * dim12, dest + i0 * dim12);
212
178
  i0 += nq;
213
179
  }
214
180
  return i0;
215
181
  }
216
182
 
217
-
218
183
  namespace {
219
184
 
220
185
  void pack_LUT_1_q_map(
221
- int nq, const int *q_map,
186
+ int nq,
187
+ const int* q_map,
222
188
  int nsq,
223
- const uint8_t *src,
224
- uint8_t *dest)
225
- {
226
-
227
- for(int qi = 0; qi < nq; qi++) {
189
+ const uint8_t* src,
190
+ uint8_t* dest) {
191
+ for (int qi = 0; qi < nq; qi++) {
228
192
  int q = q_map[qi];
229
- for(int sq = 0; sq < nsq; sq += 2) {
230
- memcpy(
231
- dest + (sq / 2 * nq + qi) * 32,
232
- src + (q * nsq + sq) * 16,
233
- 16
234
- );
235
- memcpy(
236
- dest + (sq / 2 * nq + qi) * 32 + 16,
237
- src + (q * nsq + sq + 1) * 16,
238
- 16
239
- );
193
+ for (int sq = 0; sq < nsq; sq += 2) {
194
+ memcpy(dest + (sq / 2 * nq + qi) * 32,
195
+ src + (q * nsq + sq) * 16,
196
+ 16);
197
+ memcpy(dest + (sq / 2 * nq + qi) * 32 + 16,
198
+ src + (q * nsq + sq + 1) * 16,
199
+ 16);
240
200
  }
241
201
  }
242
-
243
202
  }
244
203
 
245
204
  } // anonymous namespace
246
205
 
247
206
  int pq4_pack_LUT_qbs_q_map(
248
- int qbs, int nsq,
249
- const uint8_t *src,
250
- const int * q_map,
251
- uint8_t *dest)
252
- {
207
+ int qbs,
208
+ int nsq,
209
+ const uint8_t* src,
210
+ const int* q_map,
211
+ uint8_t* dest) {
253
212
  FAISS_THROW_IF_NOT(nsq % 2 == 0);
254
213
  size_t dim12 = 16 * nsq;
255
214
  int i0 = 0;
256
215
  int qi = qbs;
257
- while(qi) {
216
+ while (qi) {
258
217
  int nq = qi & 15;
259
218
  qi >>= 4;
260
- pack_LUT_1_q_map(
261
- nq, q_map + i0, nsq,
262
- src,
263
- dest + i0 * dim12
264
- );
219
+ pack_LUT_1_q_map(nq, q_map + i0, nsq, src, dest + i0 * dim12);
265
220
  i0 += nq;
266
221
  }
267
222
  return i0;
268
223
  }
269
224
 
270
-
271
-
272
225
  } // namespace faiss