faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -10,60 +10,58 @@
10
10
  #ifndef FAISS_INDEX_IVF_FLAT_H
11
11
  #define FAISS_INDEX_IVF_FLAT_H
12
12
 
13
- #include <unordered_map>
14
13
  #include <stdint.h>
14
+ #include <unordered_map>
15
15
 
16
16
  #include <faiss/IndexIVF.h>
17
17
 
18
-
19
18
  namespace faiss {
20
19
 
21
20
  /** Inverted file with stored vectors. Here the inverted file
22
21
  * pre-selects the vectors to be searched, but they are not otherwise
23
22
  * encoded, the code array just contains the raw float entries.
24
23
  */
25
- struct IndexIVFFlat: IndexIVF {
26
-
27
- IndexIVFFlat (
28
- Index * quantizer, size_t d, size_t nlist_,
24
+ struct IndexIVFFlat : IndexIVF {
25
+ IndexIVFFlat(
26
+ Index* quantizer,
27
+ size_t d,
28
+ size_t nlist_,
29
29
  MetricType = METRIC_L2);
30
30
 
31
- /// same as add_with_ids, with precomputed coarse quantizer
32
- virtual void add_core (idx_t n, const float * x, const int64_t *xids,
33
- const int64_t *precomputed_idx);
34
-
35
- /// implemented for all IndexIVF* classes
36
- void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
37
-
38
- void encode_vectors(idx_t n, const float* x,
39
- const idx_t *list_nos,
40
- uint8_t * codes,
41
- bool include_listnos=false) const override;
31
+ void add_core(
32
+ idx_t n,
33
+ const float* x,
34
+ const idx_t* xids,
35
+ const idx_t* precomputed_idx) override;
42
36
 
37
+ void encode_vectors(
38
+ idx_t n,
39
+ const float* x,
40
+ const idx_t* list_nos,
41
+ uint8_t* codes,
42
+ bool include_listnos = false) const override;
43
43
 
44
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
45
- const override;
44
+ InvertedListScanner* get_InvertedListScanner(
45
+ bool store_pairs) const override;
46
46
 
47
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
48
+ const override;
47
49
 
48
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
49
- float* recons) const override;
50
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
50
51
 
51
- void sa_decode (idx_t n, const uint8_t *bytes,
52
- float *x) const override;
53
-
54
- IndexIVFFlat () {}
52
+ IndexIVFFlat() {}
55
53
  };
56
54
 
57
-
58
- struct IndexIVFFlatDedup: IndexIVFFlat {
59
-
55
+ struct IndexIVFFlatDedup : IndexIVFFlat {
60
56
  /** Maps ids stored in the index to the ids of vectors that are
61
57
  * the same. When a vector is unique, it does not appear in the
62
58
  * instances map */
63
- std::unordered_multimap <idx_t, idx_t> instances;
59
+ std::unordered_multimap<idx_t, idx_t> instances;
64
60
 
65
- IndexIVFFlatDedup (
66
- Index * quantizer, size_t d, size_t nlist_,
61
+ IndexIVFFlatDedup(
62
+ Index* quantizer,
63
+ size_t d,
64
+ size_t nlist_,
67
65
  MetricType = METRIC_L2);
68
66
 
69
67
  /// also dedups the training set
@@ -72,38 +70,37 @@ struct IndexIVFFlatDedup: IndexIVFFlat {
72
70
  /// implemented for all IndexIVF* classes
73
71
  void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
74
72
 
75
- void search_preassigned (idx_t n, const float *x, idx_t k,
76
- const idx_t *assign,
77
- const float *centroid_dis,
78
- float *distances, idx_t *labels,
79
- bool store_pairs,
80
- const IVFSearchParameters *params=nullptr,
81
- IndexIVFStats *stats=nullptr
82
- ) const override;
73
+ void search_preassigned(
74
+ idx_t n,
75
+ const float* x,
76
+ idx_t k,
77
+ const idx_t* assign,
78
+ const float* centroid_dis,
79
+ float* distances,
80
+ idx_t* labels,
81
+ bool store_pairs,
82
+ const IVFSearchParameters* params = nullptr,
83
+ IndexIVFStats* stats = nullptr) const override;
83
84
 
84
85
  size_t remove_ids(const IDSelector& sel) override;
85
86
 
86
87
  /// not implemented
87
88
  void range_search(
88
- idx_t n,
89
- const float* x,
90
- float radius,
91
- RangeSearchResult* result) const override;
89
+ idx_t n,
90
+ const float* x,
91
+ float radius,
92
+ RangeSearchResult* result) const override;
92
93
 
93
94
  /// not implemented
94
- void update_vectors (int nv, const idx_t *idx, const float *v) override;
95
+ void update_vectors(int nv, const idx_t* idx, const float* v) override;
95
96
 
96
97
  /// not implemented
97
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
98
- float* recons) const override;
99
-
100
- IndexIVFFlatDedup () {}
101
-
98
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
99
+ const override;
102
100
 
101
+ IndexIVFFlatDedup() {}
103
102
  };
104
103
 
105
-
106
-
107
104
  } // namespace faiss
108
105
 
109
106
  #endif
@@ -9,17 +9,17 @@
9
9
 
10
10
  #include <faiss/IndexIVFPQ.h>
11
11
 
12
+ #include <stdint.h>
13
+ #include <cassert>
12
14
  #include <cinttypes>
13
15
  #include <cmath>
14
16
  #include <cstdio>
15
- #include <cassert>
16
- #include <stdint.h>
17
17
 
18
18
  #include <algorithm>
19
19
 
20
20
  #include <faiss/utils/Heap.h>
21
- #include <faiss/utils/utils.h>
22
21
  #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/utils.h>
23
23
 
24
24
  #include <faiss/Clustering.h>
25
25
  #include <faiss/IndexFlat.h>
@@ -36,12 +36,15 @@ namespace faiss {
36
36
  * IndexIVFPQ implementation
37
37
  ******************************************/
38
38
 
39
- IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist,
40
- size_t M, size_t nbits_per_idx, MetricType metric):
41
- IndexIVF (quantizer, d, nlist, 0, metric),
42
- pq (d, M, nbits_per_idx)
43
- {
44
- FAISS_THROW_IF_NOT (nbits_per_idx <= 8);
39
+ IndexIVFPQ::IndexIVFPQ(
40
+ Index* quantizer,
41
+ size_t d,
42
+ size_t nlist,
43
+ size_t M,
44
+ size_t nbits_per_idx,
45
+ MetricType metric)
46
+ : IndexIVF(quantizer, d, nlist, 0, metric), pq(d, M, nbits_per_idx) {
47
+ FAISS_THROW_IF_NOT(nbits_per_idx <= 8);
45
48
  code_size = pq.code_size;
46
49
  invlists->code_size = code_size;
47
50
  is_trained = false;
@@ -52,202 +55,197 @@ IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist,
52
55
  polysemous_training = nullptr;
53
56
  do_polysemous_training = false;
54
57
  polysemous_ht = 0;
55
-
56
58
  }
57
59
 
58
-
59
60
  /****************************************************************
60
61
  * training */
61
62
 
62
- void IndexIVFPQ::train_residual (idx_t n, const float *x)
63
- {
64
- train_residual_o (n, x, nullptr);
63
+ void IndexIVFPQ::train_residual(idx_t n, const float* x) {
64
+ train_residual_o(n, x, nullptr);
65
65
  }
66
66
 
67
+ void IndexIVFPQ::train_residual_o(idx_t n, const float* x, float* residuals_2) {
68
+ const float* x_in = x;
67
69
 
68
- void IndexIVFPQ::train_residual_o (idx_t n, const float *x, float *residuals_2)
69
- {
70
- const float * x_in = x;
71
-
72
- x = fvecs_maybe_subsample (
73
- d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub,
74
- x, verbose, pq.cp.seed);
70
+ x = fvecs_maybe_subsample(
71
+ d,
72
+ (size_t*)&n,
73
+ pq.cp.max_points_per_centroid * pq.ksub,
74
+ x,
75
+ verbose,
76
+ pq.cp.seed);
75
77
 
76
- ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
78
+ ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
77
79
 
78
- const float *trainset;
80
+ const float* trainset;
79
81
  ScopeDeleter<float> del_residuals;
80
82
  if (by_residual) {
81
- if(verbose) printf("computing residuals\n");
82
- idx_t * assign = new idx_t [n]; // assignement to coarse centroids
83
- ScopeDeleter<idx_t> del (assign);
84
- quantizer->assign (n, x, assign);
85
- float *residuals = new float [n * d];
86
- del_residuals.set (residuals);
83
+ if (verbose)
84
+ printf("computing residuals\n");
85
+ idx_t* assign = new idx_t[n]; // assignement to coarse centroids
86
+ ScopeDeleter<idx_t> del(assign);
87
+ quantizer->assign(n, x, assign);
88
+ float* residuals = new float[n * d];
89
+ del_residuals.set(residuals);
87
90
  for (idx_t i = 0; i < n; i++)
88
- quantizer->compute_residual (x + i * d, residuals+i*d, assign[i]);
91
+ quantizer->compute_residual(
92
+ x + i * d, residuals + i * d, assign[i]);
89
93
 
90
94
  trainset = residuals;
91
95
  } else {
92
96
  trainset = x;
93
97
  }
94
98
  if (verbose)
95
- printf ("training %zdx%zd product quantizer on %" PRId64 " vectors in %dD\n",
96
- pq.M, pq.ksub, n, d);
99
+ printf("training %zdx%zd product quantizer on %" PRId64
100
+ " vectors in %dD\n",
101
+ pq.M,
102
+ pq.ksub,
103
+ n,
104
+ d);
97
105
  pq.verbose = verbose;
98
- pq.train (n, trainset);
106
+ pq.train(n, trainset);
99
107
 
100
108
  if (do_polysemous_training) {
101
109
  if (verbose)
102
110
  printf("doing polysemous training for PQ\n");
103
111
  PolysemousTraining default_pt;
104
- PolysemousTraining *pt = polysemous_training;
105
- if (!pt) pt = &default_pt;
106
- pt->optimize_pq_for_hamming (pq, n, trainset);
112
+ PolysemousTraining* pt = polysemous_training;
113
+ if (!pt)
114
+ pt = &default_pt;
115
+ pt->optimize_pq_for_hamming(pq, n, trainset);
107
116
  }
108
117
 
109
118
  // prepare second-level residuals for refine PQ
110
119
  if (residuals_2) {
111
- uint8_t *train_codes = new uint8_t [pq.code_size * n];
112
- ScopeDeleter<uint8_t> del (train_codes);
113
- pq.compute_codes (trainset, train_codes, n);
120
+ uint8_t* train_codes = new uint8_t[pq.code_size * n];
121
+ ScopeDeleter<uint8_t> del(train_codes);
122
+ pq.compute_codes(trainset, train_codes, n);
114
123
 
115
124
  for (idx_t i = 0; i < n; i++) {
116
- const float *xx = trainset + i * d;
117
- float * res = residuals_2 + i * d;
118
- pq.decode (train_codes + i * pq.code_size, res);
125
+ const float* xx = trainset + i * d;
126
+ float* res = residuals_2 + i * d;
127
+ pq.decode(train_codes + i * pq.code_size, res);
119
128
  for (int j = 0; j < d; j++)
120
129
  res[j] = xx[j] - res[j];
121
130
  }
122
-
123
131
  }
124
132
 
125
133
  if (by_residual) {
126
- precompute_table ();
134
+ precompute_table();
127
135
  }
128
-
129
136
  }
130
137
 
131
-
132
-
133
-
134
-
135
-
136
138
  /****************************************************************
137
139
  * IVFPQ as codec */
138
140
 
139
-
140
141
  /* produce a binary signature based on the residual vector */
141
- void IndexIVFPQ::encode (idx_t key, const float * x, uint8_t * code) const
142
- {
142
+ void IndexIVFPQ::encode(idx_t key, const float* x, uint8_t* code) const {
143
143
  if (by_residual) {
144
144
  std::vector<float> residual_vec(d);
145
- quantizer->compute_residual (x, residual_vec.data(), key);
146
- pq.compute_code (residual_vec.data(), code);
147
- }
148
- else pq.compute_code (x, code);
145
+ quantizer->compute_residual(x, residual_vec.data(), key);
146
+ pq.compute_code(residual_vec.data(), code);
147
+ } else
148
+ pq.compute_code(x, code);
149
149
  }
150
150
 
151
- void IndexIVFPQ::encode_multiple (size_t n, idx_t *keys,
152
- const float * x, uint8_t * xcodes,
153
- bool compute_keys) const
154
- {
151
+ void IndexIVFPQ::encode_multiple(
152
+ size_t n,
153
+ idx_t* keys,
154
+ const float* x,
155
+ uint8_t* xcodes,
156
+ bool compute_keys) const {
155
157
  if (compute_keys)
156
- quantizer->assign (n, x, keys);
158
+ quantizer->assign(n, x, keys);
157
159
 
158
- encode_vectors (n, x, keys, xcodes);
160
+ encode_vectors(n, x, keys, xcodes);
159
161
  }
160
162
 
161
- void IndexIVFPQ::decode_multiple (size_t n, const idx_t *keys,
162
- const uint8_t * xcodes, float * x) const
163
- {
164
- pq.decode (xcodes, x, n);
163
+ void IndexIVFPQ::decode_multiple(
164
+ size_t n,
165
+ const idx_t* keys,
166
+ const uint8_t* xcodes,
167
+ float* x) const {
168
+ pq.decode(xcodes, x, n);
165
169
  if (by_residual) {
166
- std::vector<float> centroid (d);
170
+ std::vector<float> centroid(d);
167
171
  for (size_t i = 0; i < n; i++) {
168
- quantizer->reconstruct (keys[i], centroid.data());
169
- float *xi = x + i * d;
172
+ quantizer->reconstruct(keys[i], centroid.data());
173
+ float* xi = x + i * d;
170
174
  for (size_t j = 0; j < d; j++) {
171
- xi [j] += centroid [j];
175
+ xi[j] += centroid[j];
172
176
  }
173
177
  }
174
178
  }
175
179
  }
176
180
 
177
-
178
-
179
-
180
181
  /****************************************************************
181
182
  * add */
182
183
 
183
-
184
- void IndexIVFPQ::add_with_ids (idx_t n, const float * x, const idx_t *xids)
185
- {
186
- add_core_o (n, x, xids, nullptr);
184
+ void IndexIVFPQ::add_core(
185
+ idx_t n,
186
+ const float* x,
187
+ const idx_t* xids,
188
+ const idx_t* coarse_idx) {
189
+ add_core_o(n, x, xids, nullptr, coarse_idx);
187
190
  }
188
191
 
189
-
190
- static float * compute_residuals (
191
- const Index *quantizer,
192
- Index::idx_t n, const float* x,
193
- const Index::idx_t *list_nos)
194
- {
192
+ static float* compute_residuals(
193
+ const Index* quantizer,
194
+ Index::idx_t n,
195
+ const float* x,
196
+ const Index::idx_t* list_nos) {
195
197
  size_t d = quantizer->d;
196
- float *residuals = new float [n * d];
198
+ float* residuals = new float[n * d];
197
199
  // TODO: parallelize?
198
200
  for (size_t i = 0; i < n; i++) {
199
201
  if (list_nos[i] < 0)
200
- memset (residuals + i * d, 0, sizeof(*residuals) * d);
202
+ memset(residuals + i * d, 0, sizeof(*residuals) * d);
201
203
  else
202
- quantizer->compute_residual (
203
- x + i * d, residuals + i * d, list_nos[i]);
204
+ quantizer->compute_residual(
205
+ x + i * d, residuals + i * d, list_nos[i]);
204
206
  }
205
207
  return residuals;
206
208
  }
207
209
 
208
- void IndexIVFPQ::encode_vectors(idx_t n, const float* x,
209
- const idx_t *list_nos,
210
- uint8_t * codes,
211
- bool include_listnos) const
212
- {
210
+ void IndexIVFPQ::encode_vectors(
211
+ idx_t n,
212
+ const float* x,
213
+ const idx_t* list_nos,
214
+ uint8_t* codes,
215
+ bool include_listnos) const {
213
216
  if (by_residual) {
214
- float *to_encode = compute_residuals (quantizer, n, x, list_nos);
215
- ScopeDeleter<float> del (to_encode);
216
- pq.compute_codes (to_encode, codes, n);
217
+ float* to_encode = compute_residuals(quantizer, n, x, list_nos);
218
+ ScopeDeleter<float> del(to_encode);
219
+ pq.compute_codes(to_encode, codes, n);
217
220
  } else {
218
- pq.compute_codes (x, codes, n);
221
+ pq.compute_codes(x, codes, n);
219
222
  }
220
223
 
221
224
  if (include_listnos) {
222
225
  size_t coarse_size = coarse_code_size();
223
226
  for (idx_t i = n - 1; i >= 0; i--) {
224
- uint8_t * code = codes + i * (coarse_size + code_size);
225
- memmove (code + coarse_size,
226
- codes + i * code_size, code_size);
227
- encode_listno (list_nos[i], code);
227
+ uint8_t* code = codes + i * (coarse_size + code_size);
228
+ memmove(code + coarse_size, codes + i * code_size, code_size);
229
+ encode_listno(list_nos[i], code);
228
230
  }
229
231
  }
230
232
  }
231
233
 
232
-
233
-
234
- void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
235
- float *x) const
236
- {
237
- size_t coarse_size = coarse_code_size ();
234
+ void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
235
+ size_t coarse_size = coarse_code_size();
238
236
 
239
237
  #pragma omp parallel
240
238
  {
241
- std::vector<float> residual (d);
239
+ std::vector<float> residual(d);
242
240
 
243
241
  #pragma omp for
244
242
  for (idx_t i = 0; i < n; i++) {
245
- const uint8_t *code = codes + i * (code_size + coarse_size);
246
- int64_t list_no = decode_listno (code);
247
- float *xi = x + i * d;
248
- pq.decode (code + coarse_size, xi);
243
+ const uint8_t* code = codes + i * (code_size + coarse_size);
244
+ int64_t list_no = decode_listno(code);
245
+ float* xi = x + i * d;
246
+ pq.decode(code + coarse_size, xi);
249
247
  if (by_residual) {
250
- quantizer->reconstruct (list_no, residual.data());
248
+ quantizer->reconstruct(list_no, residual.data());
251
249
  for (size_t j = 0; j < d; j++) {
252
250
  xi[j] += residual[j];
253
251
  }
@@ -256,120 +254,127 @@ void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
256
254
  }
257
255
  }
258
256
 
259
-
260
- void IndexIVFPQ::add_core_o (idx_t n, const float * x, const idx_t *xids,
261
- float *residuals_2, const idx_t *precomputed_idx)
262
- {
263
-
257
+ void IndexIVFPQ::add_core_o(
258
+ idx_t n,
259
+ const float* x,
260
+ const idx_t* xids,
261
+ float* residuals_2,
262
+ const idx_t* precomputed_idx) {
264
263
  idx_t bs = 32768;
265
264
  if (n > bs) {
266
265
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
267
266
  idx_t i1 = std::min(i0 + bs, n);
268
267
  if (verbose) {
269
- printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64 " / %" PRId64 "\n",
270
- i0, i1, n);
268
+ printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
269
+ " / %" PRId64 "\n",
270
+ i0,
271
+ i1,
272
+ n);
271
273
  }
272
- add_core_o (i1 - i0, x + i0 * d,
273
- xids ? xids + i0 : nullptr,
274
- residuals_2 ? residuals_2 + i0 * d : nullptr,
275
- precomputed_idx ? precomputed_idx + i0 : nullptr);
274
+ add_core_o(
275
+ i1 - i0,
276
+ x + i0 * d,
277
+ xids ? xids + i0 : nullptr,
278
+ residuals_2 ? residuals_2 + i0 * d : nullptr,
279
+ precomputed_idx ? precomputed_idx + i0 : nullptr);
276
280
  }
277
281
  return;
278
282
  }
279
283
 
280
284
  InterruptCallback::check();
281
285
 
282
- direct_map.check_can_add (xids);
286
+ direct_map.check_can_add(xids);
283
287
 
284
- FAISS_THROW_IF_NOT (is_trained);
285
- double t0 = getmillisecs ();
286
- const idx_t * idx;
288
+ FAISS_THROW_IF_NOT(is_trained);
289
+ double t0 = getmillisecs();
290
+ const idx_t* idx;
287
291
  ScopeDeleter<idx_t> del_idx;
288
292
 
289
293
  if (precomputed_idx) {
290
294
  idx = precomputed_idx;
291
295
  } else {
292
- idx_t * idx0 = new idx_t [n];
293
- del_idx.set (idx0);
294
- quantizer->assign (n, x, idx0);
296
+ idx_t* idx0 = new idx_t[n];
297
+ del_idx.set(idx0);
298
+ quantizer->assign(n, x, idx0);
295
299
  idx = idx0;
296
300
  }
297
301
 
298
- double t1 = getmillisecs ();
299
- uint8_t * xcodes = new uint8_t [n * code_size];
300
- ScopeDeleter<uint8_t> del_xcodes (xcodes);
302
+ double t1 = getmillisecs();
303
+ uint8_t* xcodes = new uint8_t[n * code_size];
304
+ ScopeDeleter<uint8_t> del_xcodes(xcodes);
301
305
 
302
- const float *to_encode = nullptr;
306
+ const float* to_encode = nullptr;
303
307
  ScopeDeleter<float> del_to_encode;
304
308
 
305
309
  if (by_residual) {
306
- to_encode = compute_residuals (quantizer, n, x, idx);
307
- del_to_encode.set (to_encode);
310
+ to_encode = compute_residuals(quantizer, n, x, idx);
311
+ del_to_encode.set(to_encode);
308
312
  } else {
309
313
  to_encode = x;
310
314
  }
311
- pq.compute_codes (to_encode, xcodes, n);
315
+ pq.compute_codes(to_encode, xcodes, n);
312
316
 
313
- double t2 = getmillisecs ();
317
+ double t2 = getmillisecs();
314
318
  // TODO: parallelize?
315
319
  size_t n_ignore = 0;
316
320
  for (size_t i = 0; i < n; i++) {
317
321
  idx_t key = idx[i];
318
322
  idx_t id = xids ? xids[i] : ntotal + i;
319
323
  if (key < 0) {
320
- direct_map.add_single_id (id, -1, 0);
321
- n_ignore ++;
324
+ direct_map.add_single_id(id, -1, 0);
325
+ n_ignore++;
322
326
  if (residuals_2)
323
- memset (residuals_2, 0, sizeof(*residuals_2) * d);
327
+ memset(residuals_2, 0, sizeof(*residuals_2) * d);
324
328
  continue;
325
329
  }
326
330
 
327
- uint8_t *code = xcodes + i * code_size;
328
- size_t offset = invlists->add_entry (key, id, code);
331
+ uint8_t* code = xcodes + i * code_size;
332
+ size_t offset = invlists->add_entry(key, id, code);
329
333
 
330
334
  if (residuals_2) {
331
- float *res2 = residuals_2 + i * d;
332
- const float *xi = to_encode + i * d;
333
- pq.decode (code, res2);
335
+ float* res2 = residuals_2 + i * d;
336
+ const float* xi = to_encode + i * d;
337
+ pq.decode(code, res2);
334
338
  for (int j = 0; j < d; j++)
335
339
  res2[j] = xi[j] - res2[j];
336
340
  }
337
341
 
338
- direct_map.add_single_id (id, key, offset);
342
+ direct_map.add_single_id(id, key, offset);
339
343
  }
340
344
 
341
- double t3 = getmillisecs ();
342
- if(verbose) {
345
+ double t3 = getmillisecs();
346
+ if (verbose) {
343
347
  char comment[100] = {0};
344
348
  if (n_ignore > 0)
345
- snprintf (comment, 100, "(%zd vectors ignored)", n_ignore);
349
+ snprintf(comment, 100, "(%zd vectors ignored)", n_ignore);
346
350
  printf(" add_core times: %.3f %.3f %.3f %s\n",
347
- t1 - t0, t2 - t1, t3 - t2, comment);
351
+ t1 - t0,
352
+ t2 - t1,
353
+ t3 - t2,
354
+ comment);
348
355
  }
349
356
  ntotal += n;
350
357
  }
351
358
 
352
-
353
- void IndexIVFPQ::reconstruct_from_offset (int64_t list_no, int64_t offset,
354
- float* recons) const
355
- {
356
- const uint8_t* code = invlists->get_single_code (list_no, offset);
359
+ void IndexIVFPQ::reconstruct_from_offset(
360
+ int64_t list_no,
361
+ int64_t offset,
362
+ float* recons) const {
363
+ const uint8_t* code = invlists->get_single_code(list_no, offset);
357
364
 
358
365
  if (by_residual) {
359
366
  std::vector<float> centroid(d);
360
- quantizer->reconstruct (list_no, centroid.data());
367
+ quantizer->reconstruct(list_no, centroid.data());
361
368
 
362
- pq.decode (code, recons);
369
+ pq.decode(code, recons);
363
370
  for (int i = 0; i < d; ++i) {
364
371
  recons[i] += centroid[i];
365
372
  }
366
373
  } else {
367
- pq.decode (code, recons);
374
+ pq.decode(code, recons);
368
375
  }
369
376
  }
370
377
 
371
-
372
-
373
378
  /// 2G by default, accommodates tables up to PQ32 w/ 65536 centroids
374
379
  size_t precomputed_table_max_bytes = ((size_t)1) << 31;
375
380
 
@@ -403,20 +408,18 @@ size_t precomputed_table_max_bytes = ((size_t)1) << 31;
403
408
  * is faster when the length of the lists is > ksub * M.
404
409
  */
405
410
 
406
- void initialize_IVFPQ_precomputed_table (
407
- int &use_precomputed_table,
408
- const Index *quantizer,
409
- const ProductQuantizer &pq,
410
- AlignedTable<float> & precomputed_table,
411
- bool verbose
412
- )
413
- {
411
+ void initialize_IVFPQ_precomputed_table(
412
+ int& use_precomputed_table,
413
+ const Index* quantizer,
414
+ const ProductQuantizer& pq,
415
+ AlignedTable<float>& precomputed_table,
416
+ bool verbose) {
414
417
  size_t nlist = quantizer->ntotal;
415
418
  size_t d = quantizer->d;
416
419
  FAISS_THROW_IF_NOT(d == pq.d);
417
420
 
418
421
  if (use_precomputed_table == -1) {
419
- precomputed_table.resize (0);
422
+ precomputed_table.resize(0);
420
423
  return;
421
424
  }
422
425
 
@@ -424,23 +427,23 @@ void initialize_IVFPQ_precomputed_table (
424
427
  if (quantizer->metric_type == METRIC_INNER_PRODUCT) {
425
428
  if (verbose) {
426
429
  printf("IndexIVFPQ::precompute_table: precomputed "
427
- "tables not needed for inner product quantizers\n");
430
+ "tables not needed for inner product quantizers\n");
428
431
  }
429
- precomputed_table.resize (0);
432
+ precomputed_table.resize(0);
430
433
  return;
431
434
  }
432
- const MultiIndexQuantizer *miq =
433
- dynamic_cast<const MultiIndexQuantizer *> (quantizer);
435
+ const MultiIndexQuantizer* miq =
436
+ dynamic_cast<const MultiIndexQuantizer*>(quantizer);
434
437
  if (miq && pq.M % miq->pq.M == 0)
435
438
  use_precomputed_table = 2;
436
439
  else {
437
440
  size_t table_size = pq.M * pq.ksub * nlist * sizeof(float);
438
441
  if (table_size > precomputed_table_max_bytes) {
439
442
  if (verbose) {
440
- printf(
441
- "IndexIVFPQ::precompute_table: not precomputing table, "
442
- "it would be too big: %zd bytes (max %zd)\n",
443
- table_size, precomputed_table_max_bytes);
443
+ printf("IndexIVFPQ::precompute_table: not precomputing table, "
444
+ "it would be too big: %zd bytes (max %zd)\n",
445
+ table_size,
446
+ precomputed_table_max_bytes);
444
447
  use_precomputed_table = 0;
445
448
  }
446
449
  return;
@@ -450,80 +453,68 @@ void initialize_IVFPQ_precomputed_table (
450
453
  } // otherwise assume user has set appropriate flag on input
451
454
 
452
455
  if (verbose) {
453
- printf ("precomputing IVFPQ tables type %d\n",
454
- use_precomputed_table);
456
+ printf("precomputing IVFPQ tables type %d\n", use_precomputed_table);
455
457
  }
456
458
 
457
459
  // squared norms of the PQ centroids
458
- std::vector<float> r_norms (pq.M * pq.ksub, NAN);
460
+ std::vector<float> r_norms(pq.M * pq.ksub, NAN);
459
461
  for (int m = 0; m < pq.M; m++)
460
462
  for (int j = 0; j < pq.ksub; j++)
461
- r_norms [m * pq.ksub + j] =
462
- fvec_norm_L2sqr (pq.get_centroids (m, j), pq.dsub);
463
+ r_norms[m * pq.ksub + j] =
464
+ fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
463
465
 
464
466
  if (use_precomputed_table == 1) {
465
-
466
- precomputed_table.resize (nlist * pq.M * pq.ksub);
467
- std::vector<float> centroid (d);
467
+ precomputed_table.resize(nlist * pq.M * pq.ksub);
468
+ std::vector<float> centroid(d);
468
469
 
469
470
  for (size_t i = 0; i < nlist; i++) {
470
- quantizer->reconstruct (i, centroid.data());
471
+ quantizer->reconstruct(i, centroid.data());
471
472
 
472
- float *tab = &precomputed_table[i * pq.M * pq.ksub];
473
- pq.compute_inner_prod_table (centroid.data(), tab);
474
- fvec_madd (pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
473
+ float* tab = &precomputed_table[i * pq.M * pq.ksub];
474
+ pq.compute_inner_prod_table(centroid.data(), tab);
475
+ fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
475
476
  }
476
477
  } else if (use_precomputed_table == 2) {
477
- const MultiIndexQuantizer *miq =
478
- dynamic_cast<const MultiIndexQuantizer *> (quantizer);
479
- FAISS_THROW_IF_NOT (miq);
480
- const ProductQuantizer &cpq = miq->pq;
481
- FAISS_THROW_IF_NOT (pq.M % cpq.M == 0);
478
+ const MultiIndexQuantizer* miq =
479
+ dynamic_cast<const MultiIndexQuantizer*>(quantizer);
480
+ FAISS_THROW_IF_NOT(miq);
481
+ const ProductQuantizer& cpq = miq->pq;
482
+ FAISS_THROW_IF_NOT(pq.M % cpq.M == 0);
482
483
 
483
484
  precomputed_table.resize(cpq.ksub * pq.M * pq.ksub);
484
485
 
485
486
  // reorder PQ centroid table
486
- std::vector<float> centroids (d * cpq.ksub, NAN);
487
+ std::vector<float> centroids(d * cpq.ksub, NAN);
487
488
 
488
489
  for (int m = 0; m < cpq.M; m++) {
489
490
  for (size_t i = 0; i < cpq.ksub; i++) {
490
- memcpy (centroids.data() + i * d + m * cpq.dsub,
491
- cpq.get_centroids (m, i),
492
- sizeof (*centroids.data()) * cpq.dsub);
491
+ memcpy(centroids.data() + i * d + m * cpq.dsub,
492
+ cpq.get_centroids(m, i),
493
+ sizeof(*centroids.data()) * cpq.dsub);
493
494
  }
494
495
  }
495
496
 
496
- pq.compute_inner_prod_tables (cpq.ksub, centroids.data (),
497
- precomputed_table.data ());
497
+ pq.compute_inner_prod_tables(
498
+ cpq.ksub, centroids.data(), precomputed_table.data());
498
499
 
499
500
  for (size_t i = 0; i < cpq.ksub; i++) {
500
- float *tab = &precomputed_table[i * pq.M * pq.ksub];
501
- fvec_madd (pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
501
+ float* tab = &precomputed_table[i * pq.M * pq.ksub];
502
+ fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
502
503
  }
503
-
504
504
  }
505
-
506
505
  }
507
506
 
508
- void IndexIVFPQ::precompute_table ()
509
- {
510
- initialize_IVFPQ_precomputed_table (
511
- use_precomputed_table, quantizer, pq, precomputed_table,
512
- verbose
513
- );
507
+ void IndexIVFPQ::precompute_table() {
508
+ initialize_IVFPQ_precomputed_table(
509
+ use_precomputed_table, quantizer, pq, precomputed_table, verbose);
514
510
  }
515
511
 
516
-
517
-
518
512
  namespace {
519
513
 
520
514
  using idx_t = Index::idx_t;
521
515
 
522
-
523
516
  #define TIC t0 = get_cycles()
524
- #define TOC get_cycles () - t0
525
-
526
-
517
+ #define TOC get_cycles() - t0
527
518
 
528
519
  /** QueryTables manages the various ways of searching an
529
520
  * IndexIVFPQ. The code contains a lot of branches, depending on:
@@ -533,43 +524,42 @@ using idx_t = Index::idx_t;
533
524
  * - polysemous_ht: are we filtering with polysemous codes?
534
525
  */
535
526
  struct QueryTables {
536
-
537
527
  /*****************************************************
538
528
  * General data from the IVFPQ
539
529
  *****************************************************/
540
530
 
541
- const IndexIVFPQ & ivfpq;
542
- const IVFSearchParameters *params;
531
+ const IndexIVFPQ& ivfpq;
532
+ const IVFSearchParameters* params;
543
533
 
544
534
  // copied from IndexIVFPQ for easier access
545
535
  int d;
546
- const ProductQuantizer & pq;
536
+ const ProductQuantizer& pq;
547
537
  MetricType metric_type;
548
538
  bool by_residual;
549
539
  int use_precomputed_table;
550
540
  int polysemous_ht;
551
541
 
552
542
  // pre-allocated data buffers
553
- float * sim_table, * sim_table_2;
554
- float * residual_vec, *decoded_vec;
543
+ float *sim_table, *sim_table_2;
544
+ float *residual_vec, *decoded_vec;
555
545
 
556
546
  // single data buffer
557
547
  std::vector<float> mem;
558
548
 
559
549
  // for table pointers
560
- std::vector<const float *> sim_table_ptrs;
561
-
562
- explicit QueryTables (const IndexIVFPQ & ivfpq,
563
- const IVFSearchParameters *params):
564
- ivfpq(ivfpq),
565
- d(ivfpq.d),
566
- pq (ivfpq.pq),
567
- metric_type (ivfpq.metric_type),
568
- by_residual (ivfpq.by_residual),
569
- use_precomputed_table (ivfpq.use_precomputed_table)
570
- {
571
- mem.resize (pq.ksub * pq.M * 2 + d * 2);
572
- sim_table = mem.data ();
550
+ std::vector<const float*> sim_table_ptrs;
551
+
552
+ explicit QueryTables(
553
+ const IndexIVFPQ& ivfpq,
554
+ const IVFSearchParameters* params)
555
+ : ivfpq(ivfpq),
556
+ d(ivfpq.d),
557
+ pq(ivfpq.pq),
558
+ metric_type(ivfpq.metric_type),
559
+ by_residual(ivfpq.by_residual),
560
+ use_precomputed_table(ivfpq.use_precomputed_table) {
561
+ mem.resize(pq.ksub * pq.M * 2 + d * 2);
562
+ sim_table = mem.data();
573
563
  sim_table_2 = sim_table + pq.ksub * pq.M;
574
564
  residual_vec = sim_table_2 + pq.ksub * pq.M;
575
565
  decoded_vec = residual_vec + d;
@@ -577,14 +567,14 @@ struct QueryTables {
577
567
  // for polysemous
578
568
  polysemous_ht = ivfpq.polysemous_ht;
579
569
  if (auto ivfpq_params =
580
- dynamic_cast<const IVFPQSearchParameters *>(params)) {
570
+ dynamic_cast<const IVFPQSearchParameters*>(params)) {
581
571
  polysemous_ht = ivfpq_params->polysemous_ht;
582
572
  }
583
- if (polysemous_ht != 0) {
584
- q_code.resize (pq.code_size);
573
+ if (polysemous_ht != 0) {
574
+ q_code.resize(pq.code_size);
585
575
  }
586
576
  init_list_cycles = 0;
587
- sim_table_ptrs.resize (pq.M);
577
+ sim_table_ptrs.resize(pq.M);
588
578
  }
589
579
 
590
580
  /*****************************************************
@@ -592,29 +582,29 @@ struct QueryTables {
592
582
  *****************************************************/
593
583
 
594
584
  // field specific to query
595
- const float * qi;
585
+ const float* qi;
596
586
 
597
587
  // query-specific intialization
598
- void init_query (const float * qi) {
588
+ void init_query(const float* qi) {
599
589
  this->qi = qi;
600
590
  if (metric_type == METRIC_INNER_PRODUCT)
601
- init_query_IP ();
591
+ init_query_IP();
602
592
  else
603
- init_query_L2 ();
593
+ init_query_L2();
604
594
  if (!by_residual && polysemous_ht != 0)
605
- pq.compute_code (qi, q_code.data());
595
+ pq.compute_code(qi, q_code.data());
606
596
  }
607
597
 
608
- void init_query_IP () {
598
+ void init_query_IP() {
609
599
  // precompute some tables specific to the query qi
610
- pq.compute_inner_prod_table (qi, sim_table);
600
+ pq.compute_inner_prod_table(qi, sim_table);
611
601
  }
612
602
 
613
- void init_query_L2 () {
603
+ void init_query_L2() {
614
604
  if (!by_residual) {
615
- pq.compute_distance_table (qi, sim_table);
605
+ pq.compute_distance_table(qi, sim_table);
616
606
  } else if (use_precomputed_table) {
617
- pq.compute_inner_prod_table (qi, sim_table_2);
607
+ pq.compute_inner_prod_table(qi, sim_table_2);
618
608
  }
619
609
  }
620
610
 
@@ -632,96 +622,95 @@ struct QueryTables {
632
622
  /// once we know the query and the centroid, we can prepare the
633
623
  /// sim_table that will be used for accumulation
634
624
  /// and dis0, the initial value
635
- float precompute_list_tables () {
625
+ float precompute_list_tables() {
636
626
  float dis0 = 0;
637
- uint64_t t0; TIC;
627
+ uint64_t t0;
628
+ TIC;
638
629
  if (by_residual) {
639
630
  if (metric_type == METRIC_INNER_PRODUCT)
640
- dis0 = precompute_list_tables_IP ();
631
+ dis0 = precompute_list_tables_IP();
641
632
  else
642
- dis0 = precompute_list_tables_L2 ();
633
+ dis0 = precompute_list_tables_L2();
643
634
  }
644
635
  init_list_cycles += TOC;
645
636
  return dis0;
646
- }
637
+ }
647
638
 
648
- float precompute_list_table_pointers () {
639
+ float precompute_list_table_pointers() {
649
640
  float dis0 = 0;
650
- uint64_t t0; TIC;
641
+ uint64_t t0;
642
+ TIC;
651
643
  if (by_residual) {
652
644
  if (metric_type == METRIC_INNER_PRODUCT)
653
- FAISS_THROW_MSG ("not implemented");
645
+ FAISS_THROW_MSG("not implemented");
654
646
  else
655
- dis0 = precompute_list_table_pointers_L2 ();
647
+ dis0 = precompute_list_table_pointers_L2();
656
648
  }
657
649
  init_list_cycles += TOC;
658
650
  return dis0;
659
- }
651
+ }
660
652
 
661
653
  /*****************************************************
662
654
  * compute tables for inner prod
663
655
  *****************************************************/
664
656
 
665
- float precompute_list_tables_IP ()
666
- {
657
+ float precompute_list_tables_IP() {
667
658
  // prepare the sim_table that will be used for accumulation
668
659
  // and dis0, the initial value
669
- ivfpq.quantizer->reconstruct (key, decoded_vec);
660
+ ivfpq.quantizer->reconstruct(key, decoded_vec);
670
661
  // decoded_vec = centroid
671
- float dis0 = fvec_inner_product (qi, decoded_vec, d);
662
+ float dis0 = fvec_inner_product(qi, decoded_vec, d);
672
663
 
673
664
  if (polysemous_ht) {
674
665
  for (int i = 0; i < d; i++) {
675
- residual_vec [i] = qi[i] - decoded_vec[i];
666
+ residual_vec[i] = qi[i] - decoded_vec[i];
676
667
  }
677
- pq.compute_code (residual_vec, q_code.data());
668
+ pq.compute_code(residual_vec, q_code.data());
678
669
  }
679
670
  return dis0;
680
671
  }
681
672
 
682
-
683
673
  /*****************************************************
684
674
  * compute tables for L2 distance
685
675
  *****************************************************/
686
676
 
687
- float precompute_list_tables_L2 ()
688
- {
677
+ float precompute_list_tables_L2() {
689
678
  float dis0 = 0;
690
679
 
691
680
  if (use_precomputed_table == 0 || use_precomputed_table == -1) {
692
- ivfpq.quantizer->compute_residual (qi, residual_vec, key);
693
- pq.compute_distance_table (residual_vec, sim_table);
681
+ ivfpq.quantizer->compute_residual(qi, residual_vec, key);
682
+ pq.compute_distance_table(residual_vec, sim_table);
694
683
 
695
684
  if (polysemous_ht != 0) {
696
- pq.compute_code (residual_vec, q_code.data());
685
+ pq.compute_code(residual_vec, q_code.data());
697
686
  }
698
687
 
699
688
  } else if (use_precomputed_table == 1) {
700
689
  dis0 = coarse_dis;
701
690
 
702
- fvec_madd (
691
+ fvec_madd(
703
692
  pq.M * pq.ksub,
704
693
  ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
705
- -2.0, sim_table_2,
706
- sim_table
707
- );
694
+ -2.0,
695
+ sim_table_2,
696
+ sim_table);
708
697
 
709
698
  if (polysemous_ht != 0) {
710
- ivfpq.quantizer->compute_residual (qi, residual_vec, key);
711
- pq.compute_code (residual_vec, q_code.data());
699
+ ivfpq.quantizer->compute_residual(qi, residual_vec, key);
700
+ pq.compute_code(residual_vec, q_code.data());
712
701
  }
713
702
 
714
703
  } else if (use_precomputed_table == 2) {
715
704
  dis0 = coarse_dis;
716
705
 
717
- const MultiIndexQuantizer *miq =
718
- dynamic_cast<const MultiIndexQuantizer *> (ivfpq.quantizer);
719
- FAISS_THROW_IF_NOT (miq);
720
- const ProductQuantizer &cpq = miq->pq;
706
+ const MultiIndexQuantizer* miq =
707
+ dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
708
+ FAISS_THROW_IF_NOT(miq);
709
+ const ProductQuantizer& cpq = miq->pq;
721
710
  int Mf = pq.M / cpq.M;
722
711
 
723
- const float *qtab = sim_table_2; // query-specific table
724
- float *ltab = sim_table; // (output) list-specific table
712
+ const float* qtab = sim_table_2; // query-specific table
713
+ float* ltab = sim_table; // (output) list-specific table
725
714
 
726
715
  long k = key;
727
716
  for (int cm = 0; cm < cpq.M; cm++) {
@@ -730,54 +719,48 @@ struct QueryTables {
730
719
  k >>= cpq.nbits;
731
720
 
732
721
  // get corresponding table
733
- const float *pc = ivfpq.precomputed_table.data() +
734
- (ki * pq.M + cm * Mf) * pq.ksub;
722
+ const float* pc = ivfpq.precomputed_table.data() +
723
+ (ki * pq.M + cm * Mf) * pq.ksub;
735
724
 
736
725
  if (polysemous_ht == 0) {
737
-
738
726
  // sum up with query-specific table
739
- fvec_madd (Mf * pq.ksub,
740
- pc,
741
- -2.0, qtab,
742
- ltab);
727
+ fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
743
728
  ltab += Mf * pq.ksub;
744
729
  qtab += Mf * pq.ksub;
745
730
  } else {
746
731
  for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
747
- q_code[m] = fvec_madd_and_argmin
748
- (pq.ksub, pc, -2, qtab, ltab);
732
+ q_code[m] = fvec_madd_and_argmin(
733
+ pq.ksub, pc, -2, qtab, ltab);
749
734
  pc += pq.ksub;
750
735
  ltab += pq.ksub;
751
736
  qtab += pq.ksub;
752
737
  }
753
738
  }
754
-
755
739
  }
756
740
  }
757
741
 
758
742
  return dis0;
759
743
  }
760
744
 
761
- float precompute_list_table_pointers_L2 ()
762
- {
745
+ float precompute_list_table_pointers_L2() {
763
746
  float dis0 = 0;
764
747
 
765
748
  if (use_precomputed_table == 1) {
766
749
  dis0 = coarse_dis;
767
750
 
768
- const float * s = ivfpq.precomputed_table.data() +
769
- key * pq.ksub * pq.M;
751
+ const float* s =
752
+ ivfpq.precomputed_table.data() + key * pq.ksub * pq.M;
770
753
  for (int m = 0; m < pq.M; m++) {
771
- sim_table_ptrs [m] = s;
754
+ sim_table_ptrs[m] = s;
772
755
  s += pq.ksub;
773
756
  }
774
757
  } else if (use_precomputed_table == 2) {
775
758
  dis0 = coarse_dis;
776
759
 
777
- const MultiIndexQuantizer *miq =
778
- dynamic_cast<const MultiIndexQuantizer *> (ivfpq.quantizer);
779
- FAISS_THROW_IF_NOT (miq);
780
- const ProductQuantizer &cpq = miq->pq;
760
+ const MultiIndexQuantizer* miq =
761
+ dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
762
+ FAISS_THROW_IF_NOT(miq);
763
+ const ProductQuantizer& cpq = miq->pq;
781
764
  int Mf = pq.M / cpq.M;
782
765
 
783
766
  long k = key;
@@ -786,21 +769,21 @@ struct QueryTables {
786
769
  int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
787
770
  k >>= cpq.nbits;
788
771
 
789
- const float *pc = ivfpq.precomputed_table.data() +
790
- (ki * pq.M + cm * Mf) * pq.ksub;
772
+ const float* pc = ivfpq.precomputed_table.data() +
773
+ (ki * pq.M + cm * Mf) * pq.ksub;
791
774
 
792
775
  for (int m = m0; m < m0 + Mf; m++) {
793
- sim_table_ptrs [m] = pc;
776
+ sim_table_ptrs[m] = pc;
794
777
  pc += pq.ksub;
795
778
  }
796
779
  m0 += Mf;
797
780
  }
798
781
  } else {
799
- FAISS_THROW_MSG ("need precomputed tables");
782
+ FAISS_THROW_MSG("need precomputed tables");
800
783
  }
801
784
 
802
785
  if (polysemous_ht) {
803
- FAISS_THROW_MSG ("not implemented");
786
+ FAISS_THROW_MSG("not implemented");
804
787
  // Not clear that it makes sense to implemente this,
805
788
  // because it costs M * ksub, which is what we wanted to
806
789
  // avoid with the tables pointers.
@@ -808,82 +791,72 @@ struct QueryTables {
808
791
 
809
792
  return dis0;
810
793
  }
811
-
812
-
813
794
  };
814
795
 
815
-
816
-
817
- template<class C>
796
+ template <class C>
818
797
  struct KnnSearchResults {
819
798
  idx_t key;
820
- const idx_t *ids;
799
+ const idx_t* ids;
821
800
 
822
801
  // heap params
823
802
  size_t k;
824
- float * heap_sim;
825
- idx_t * heap_ids;
803
+ float* heap_sim;
804
+ idx_t* heap_ids;
826
805
 
827
806
  size_t nup;
828
807
 
829
- inline void add (idx_t j, float dis) {
830
- if (C::cmp (heap_sim[0], dis)) {
831
- idx_t id = ids ? ids[j] : lo_build (key, j);
832
- heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
808
+ inline void add(idx_t j, float dis) {
809
+ if (C::cmp(heap_sim[0], dis)) {
810
+ idx_t id = ids ? ids[j] : lo_build(key, j);
811
+ heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
833
812
  nup++;
834
813
  }
835
814
  }
836
-
837
815
  };
838
816
 
839
- template<class C>
817
+ template <class C>
840
818
  struct RangeSearchResults {
841
819
  idx_t key;
842
- const idx_t *ids;
820
+ const idx_t* ids;
843
821
 
844
822
  // wrapped result structure
845
823
  float radius;
846
- RangeQueryResult & rres;
824
+ RangeQueryResult& rres;
847
825
 
848
- inline void add (idx_t j, float dis) {
849
- if (C::cmp (radius, dis)) {
850
- idx_t id = ids ? ids[j] : lo_build (key, j);
851
- rres.add (dis, id);
826
+ inline void add(idx_t j, float dis) {
827
+ if (C::cmp(radius, dis)) {
828
+ idx_t id = ids ? ids[j] : lo_build(key, j);
829
+ rres.add(dis, id);
852
830
  }
853
831
  }
854
832
  };
855
833
 
856
-
857
-
858
834
  /*****************************************************
859
835
  * Scaning the codes.
860
836
  * The scanning functions call their favorite precompute_*
861
837
  * function to precompute the tables they need.
862
838
  *****************************************************/
863
839
  template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
864
- struct IVFPQScannerT: QueryTables {
865
-
866
- const uint8_t * list_codes;
867
- const IDType * list_ids;
840
+ struct IVFPQScannerT : QueryTables {
841
+ const uint8_t* list_codes;
842
+ const IDType* list_ids;
868
843
  size_t list_size;
869
844
 
870
- IVFPQScannerT (const IndexIVFPQ & ivfpq, const IVFSearchParameters *params):
871
- QueryTables (ivfpq, params)
872
- {
845
+ IVFPQScannerT(const IndexIVFPQ& ivfpq, const IVFSearchParameters* params)
846
+ : QueryTables(ivfpq, params) {
873
847
  assert(METRIC_TYPE == metric_type);
874
848
  }
875
849
 
876
850
  float dis0;
877
851
 
878
- void init_list (idx_t list_no, float coarse_dis,
879
- int mode) {
852
+ void init_list(idx_t list_no, float coarse_dis, int mode) {
880
853
  this->key = list_no;
881
854
  this->coarse_dis = coarse_dis;
882
855
 
883
856
  if (mode == 2) {
884
- dis0 = precompute_list_tables ();
857
+ dis0 = precompute_list_tables();
885
858
  } else if (mode == 1) {
886
- dis0 = precompute_list_table_pointers ();
859
+ dis0 = precompute_list_table_pointers();
887
860
  }
888
861
  }
889
862
 
@@ -892,15 +865,16 @@ struct IVFPQScannerT: QueryTables {
892
865
  *****************************************************/
893
866
 
894
867
  /// version of the scan where we use precomputed tables
895
- template<class SearchResultType>
896
- void scan_list_with_table (size_t ncode, const uint8_t *codes,
897
- SearchResultType & res) const
898
- {
868
+ template <class SearchResultType>
869
+ void scan_list_with_table(
870
+ size_t ncode,
871
+ const uint8_t* codes,
872
+ SearchResultType& res) const {
899
873
  for (size_t j = 0; j < ncode; j++) {
900
874
  PQDecoder decoder(codes, pq.nbits);
901
875
  codes += pq.code_size;
902
876
  float dis = dis0;
903
- const float *tab = sim_table;
877
+ const float* tab = sim_table;
904
878
 
905
879
  for (size_t m = 0; m < pq.M; m++) {
906
880
  dis += tab[decoder.decode()];
@@ -911,43 +885,43 @@ struct IVFPQScannerT: QueryTables {
911
885
  }
912
886
  }
913
887
 
914
-
915
888
  /// tables are not precomputed, but pointers are provided to the
916
889
  /// relevant X_c|x_r tables
917
- template<class SearchResultType>
918
- void scan_list_with_pointer (size_t ncode, const uint8_t *codes,
919
- SearchResultType & res) const
920
- {
890
+ template <class SearchResultType>
891
+ void scan_list_with_pointer(
892
+ size_t ncode,
893
+ const uint8_t* codes,
894
+ SearchResultType& res) const {
921
895
  for (size_t j = 0; j < ncode; j++) {
922
896
  PQDecoder decoder(codes, pq.nbits);
923
897
  codes += pq.code_size;
924
898
 
925
899
  float dis = dis0;
926
- const float *tab = sim_table_2;
900
+ const float* tab = sim_table_2;
927
901
 
928
902
  for (size_t m = 0; m < pq.M; m++) {
929
903
  int ci = decoder.decode();
930
- dis += sim_table_ptrs [m][ci] - 2 * tab [ci];
904
+ dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
931
905
  tab += pq.ksub;
932
906
  }
933
- res.add (j, dis);
907
+ res.add(j, dis);
934
908
  }
935
909
  }
936
910
 
937
-
938
911
  /// nothing is precomputed: access residuals on-the-fly
939
- template<class SearchResultType>
940
- void scan_on_the_fly_dist (size_t ncode, const uint8_t *codes,
941
- SearchResultType &res) const
942
- {
943
- const float *dvec;
912
+ template <class SearchResultType>
913
+ void scan_on_the_fly_dist(
914
+ size_t ncode,
915
+ const uint8_t* codes,
916
+ SearchResultType& res) const {
917
+ const float* dvec;
944
918
  float dis0 = 0;
945
919
  if (by_residual) {
946
920
  if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
947
- ivfpq.quantizer->reconstruct (key, residual_vec);
948
- dis0 = fvec_inner_product (residual_vec, qi, d);
921
+ ivfpq.quantizer->reconstruct(key, residual_vec);
922
+ dis0 = fvec_inner_product(residual_vec, qi, d);
949
923
  } else {
950
- ivfpq.quantizer->compute_residual (qi, residual_vec, key);
924
+ ivfpq.quantizer->compute_residual(qi, residual_vec, key);
951
925
  }
952
926
  dvec = residual_vec;
953
927
  } else {
@@ -956,17 +930,16 @@ struct IVFPQScannerT: QueryTables {
956
930
  }
957
931
 
958
932
  for (size_t j = 0; j < ncode; j++) {
959
-
960
- pq.decode (codes, decoded_vec);
933
+ pq.decode(codes, decoded_vec);
961
934
  codes += pq.code_size;
962
935
 
963
936
  float dis;
964
937
  if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
965
- dis = dis0 + fvec_inner_product (decoded_vec, qi, d);
938
+ dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
966
939
  } else {
967
- dis = fvec_L2sqr (decoded_vec, dvec, d);
940
+ dis = fvec_L2sqr(decoded_vec, dvec, d);
968
941
  }
969
- res.add (j, dis);
942
+ res.add(j, dis);
970
943
  }
971
944
  }
972
945
 
@@ -975,110 +948,98 @@ struct IVFPQScannerT: QueryTables {
975
948
  *****************************************************/
976
949
 
977
950
  template <class HammingComputer, class SearchResultType>
978
- void scan_list_polysemous_hc (
979
- size_t ncode, const uint8_t *codes,
980
- SearchResultType & res) const
981
- {
951
+ void scan_list_polysemous_hc(
952
+ size_t ncode,
953
+ const uint8_t* codes,
954
+ SearchResultType& res) const {
982
955
  int ht = ivfpq.polysemous_ht;
983
956
  size_t n_hamming_pass = 0, nup = 0;
984
957
 
985
958
  int code_size = pq.code_size;
986
959
 
987
- HammingComputer hc (q_code.data(), code_size);
960
+ HammingComputer hc(q_code.data(), code_size);
988
961
 
989
962
  for (size_t j = 0; j < ncode; j++) {
990
- const uint8_t *b_code = codes;
991
- int hd = hc.hamming (b_code);
963
+ const uint8_t* b_code = codes;
964
+ int hd = hc.hamming(b_code);
992
965
  if (hd < ht) {
993
- n_hamming_pass ++;
966
+ n_hamming_pass++;
994
967
  PQDecoder decoder(codes, pq.nbits);
995
968
 
996
969
  float dis = dis0;
997
- const float *tab = sim_table;
970
+ const float* tab = sim_table;
998
971
 
999
972
  for (size_t m = 0; m < pq.M; m++) {
1000
973
  dis += tab[decoder.decode()];
1001
974
  tab += pq.ksub;
1002
975
  }
1003
976
 
1004
- res.add (j, dis);
977
+ res.add(j, dis);
1005
978
  }
1006
979
  codes += code_size;
1007
980
  }
1008
981
  #pragma omp critical
1009
- {
1010
- indexIVFPQ_stats.n_hamming_pass += n_hamming_pass;
1011
- }
982
+ { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
1012
983
  }
1013
984
 
1014
- template<class SearchResultType>
1015
- void scan_list_polysemous (
1016
- size_t ncode, const uint8_t *codes,
1017
- SearchResultType &res) const
1018
- {
985
+ template <class SearchResultType>
986
+ void scan_list_polysemous(
987
+ size_t ncode,
988
+ const uint8_t* codes,
989
+ SearchResultType& res) const {
1019
990
  switch (pq.code_size) {
1020
991
  #define HANDLE_CODE_SIZE(cs) \
1021
- case cs: \
1022
- scan_list_polysemous_hc \
1023
- <HammingComputer ## cs, SearchResultType> \
1024
- (ncode, codes, res); \
1025
- break
1026
- HANDLE_CODE_SIZE(4);
1027
- HANDLE_CODE_SIZE(8);
1028
- HANDLE_CODE_SIZE(16);
1029
- HANDLE_CODE_SIZE(20);
1030
- HANDLE_CODE_SIZE(32);
1031
- HANDLE_CODE_SIZE(64);
992
+ case cs: \
993
+ scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
994
+ ncode, codes, res); \
995
+ break
996
+ HANDLE_CODE_SIZE(4);
997
+ HANDLE_CODE_SIZE(8);
998
+ HANDLE_CODE_SIZE(16);
999
+ HANDLE_CODE_SIZE(20);
1000
+ HANDLE_CODE_SIZE(32);
1001
+ HANDLE_CODE_SIZE(64);
1032
1002
  #undef HANDLE_CODE_SIZE
1033
- default:
1034
- if (pq.code_size % 8 == 0)
1035
- scan_list_polysemous_hc
1036
- <HammingComputerM8, SearchResultType>
1037
- (ncode, codes, res);
1038
- else
1039
- scan_list_polysemous_hc
1040
- <HammingComputerM4, SearchResultType>
1041
- (ncode, codes, res);
1042
- break;
1003
+ default:
1004
+ scan_list_polysemous_hc<
1005
+ HammingComputerDefault,
1006
+ SearchResultType>(ncode, codes, res);
1007
+ break;
1043
1008
  }
1044
1009
  }
1045
-
1046
1010
  };
1047
1011
 
1048
-
1049
1012
  /* We put as many parameters as possible in template. Hopefully the
1050
1013
  * gain in runtime is worth the code bloat. C is the comparator < or
1051
1014
  * >, it is directly related to METRIC_TYPE. precompute_mode is how
1052
1015
  * much we precompute (2 = precompute distance tables, 1 = precompute
1053
1016
  * pointers to distances, 0 = compute distances one by one).
1054
1017
  * Currently only 2 is supported */
1055
- template<MetricType METRIC_TYPE, class C, class PQDecoder>
1056
- struct IVFPQScanner:
1057
- IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>,
1058
- InvertedListScanner
1059
- {
1018
+ template <MetricType METRIC_TYPE, class C, class PQDecoder>
1019
+ struct IVFPQScanner : IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>,
1020
+ InvertedListScanner {
1060
1021
  bool store_pairs;
1061
1022
  int precompute_mode;
1062
1023
 
1063
- IVFPQScanner(const IndexIVFPQ & ivfpq, bool store_pairs,
1064
- int precompute_mode):
1065
- IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
1066
- store_pairs(store_pairs), precompute_mode(precompute_mode)
1067
- {
1068
- }
1024
+ IVFPQScanner(const IndexIVFPQ& ivfpq, bool store_pairs, int precompute_mode)
1025
+ : IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>(
1026
+ ivfpq,
1027
+ nullptr),
1028
+ store_pairs(store_pairs),
1029
+ precompute_mode(precompute_mode) {}
1069
1030
 
1070
- void set_query (const float *query) override {
1071
- this->init_query (query);
1031
+ void set_query(const float* query) override {
1032
+ this->init_query(query);
1072
1033
  }
1073
1034
 
1074
- void set_list (idx_t list_no, float coarse_dis) override {
1075
- this->init_list (list_no, coarse_dis, precompute_mode);
1035
+ void set_list(idx_t list_no, float coarse_dis) override {
1036
+ this->init_list(list_no, coarse_dis, precompute_mode);
1076
1037
  }
1077
1038
 
1078
- float distance_to_code (const uint8_t *code) const override {
1039
+ float distance_to_code(const uint8_t* code) const override {
1079
1040
  assert(precompute_mode == 2);
1080
1041
  float dis = this->dis0;
1081
- const float *tab = this->sim_table;
1042
+ const float* tab = this->sim_table;
1082
1043
  PQDecoder decoder(code, this->pq.nbits);
1083
1044
 
1084
1045
  for (size_t m = 0; m < this->pq.M; m++) {
@@ -1088,112 +1049,100 @@ struct IVFPQScanner:
1088
1049
  return dis;
1089
1050
  }
1090
1051
 
1091
- size_t scan_codes (size_t ncode,
1092
- const uint8_t *codes,
1093
- const idx_t *ids,
1094
- float *heap_sim, idx_t *heap_ids,
1095
- size_t k) const override
1096
- {
1052
+ size_t scan_codes(
1053
+ size_t ncode,
1054
+ const uint8_t* codes,
1055
+ const idx_t* ids,
1056
+ float* heap_sim,
1057
+ idx_t* heap_ids,
1058
+ size_t k) const override {
1097
1059
  KnnSearchResults<C> res = {
1098
- /* key */ this->key,
1099
- /* ids */ this->store_pairs ? nullptr : ids,
1100
- /* k */ k,
1101
- /* heap_sim */ heap_sim,
1102
- /* heap_ids */ heap_ids,
1103
- /* nup */ 0
1104
- };
1060
+ /* key */ this->key,
1061
+ /* ids */ this->store_pairs ? nullptr : ids,
1062
+ /* k */ k,
1063
+ /* heap_sim */ heap_sim,
1064
+ /* heap_ids */ heap_ids,
1065
+ /* nup */ 0};
1105
1066
 
1106
1067
  if (this->polysemous_ht > 0) {
1107
1068
  assert(precompute_mode == 2);
1108
- this->scan_list_polysemous (ncode, codes, res);
1069
+ this->scan_list_polysemous(ncode, codes, res);
1109
1070
  } else if (precompute_mode == 2) {
1110
- this->scan_list_with_table (ncode, codes, res);
1071
+ this->scan_list_with_table(ncode, codes, res);
1111
1072
  } else if (precompute_mode == 1) {
1112
- this->scan_list_with_pointer (ncode, codes, res);
1073
+ this->scan_list_with_pointer(ncode, codes, res);
1113
1074
  } else if (precompute_mode == 0) {
1114
- this->scan_on_the_fly_dist (ncode, codes, res);
1075
+ this->scan_on_the_fly_dist(ncode, codes, res);
1115
1076
  } else {
1116
1077
  FAISS_THROW_MSG("bad precomp mode");
1117
1078
  }
1118
1079
  return res.nup;
1119
1080
  }
1120
1081
 
1121
- void scan_codes_range (size_t ncode,
1122
- const uint8_t *codes,
1123
- const idx_t *ids,
1124
- float radius,
1125
- RangeQueryResult & rres) const override
1126
- {
1082
+ void scan_codes_range(
1083
+ size_t ncode,
1084
+ const uint8_t* codes,
1085
+ const idx_t* ids,
1086
+ float radius,
1087
+ RangeQueryResult& rres) const override {
1127
1088
  RangeSearchResults<C> res = {
1128
- /* key */ this->key,
1129
- /* ids */ this->store_pairs ? nullptr : ids,
1130
- /* radius */ radius,
1131
- /* rres */ rres
1132
- };
1089
+ /* key */ this->key,
1090
+ /* ids */ this->store_pairs ? nullptr : ids,
1091
+ /* radius */ radius,
1092
+ /* rres */ rres};
1133
1093
 
1134
1094
  if (this->polysemous_ht > 0) {
1135
1095
  assert(precompute_mode == 2);
1136
- this->scan_list_polysemous (ncode, codes, res);
1096
+ this->scan_list_polysemous(ncode, codes, res);
1137
1097
  } else if (precompute_mode == 2) {
1138
- this->scan_list_with_table (ncode, codes, res);
1098
+ this->scan_list_with_table(ncode, codes, res);
1139
1099
  } else if (precompute_mode == 1) {
1140
- this->scan_list_with_pointer (ncode, codes, res);
1100
+ this->scan_list_with_pointer(ncode, codes, res);
1141
1101
  } else if (precompute_mode == 0) {
1142
- this->scan_on_the_fly_dist (ncode, codes, res);
1102
+ this->scan_on_the_fly_dist(ncode, codes, res);
1143
1103
  } else {
1144
1104
  FAISS_THROW_MSG("bad precomp mode");
1145
1105
  }
1146
-
1147
1106
  }
1148
1107
  };
1149
1108
 
1150
- template<class PQDecoder>
1151
- InvertedListScanner *get_InvertedListScanner1 (const IndexIVFPQ &index,
1152
- bool store_pairs)
1153
- {
1154
-
1155
- if (index.metric_type == METRIC_INNER_PRODUCT) {
1156
- return new IVFPQScanner
1157
- <METRIC_INNER_PRODUCT, CMin<float, idx_t>, PQDecoder>
1158
- (index, store_pairs, 2);
1109
+ template <class PQDecoder>
1110
+ InvertedListScanner* get_InvertedListScanner1(
1111
+ const IndexIVFPQ& index,
1112
+ bool store_pairs) {
1113
+ if (index.metric_type == METRIC_INNER_PRODUCT) {
1114
+ return new IVFPQScanner<
1115
+ METRIC_INNER_PRODUCT,
1116
+ CMin<float, idx_t>,
1117
+ PQDecoder>(index, store_pairs, 2);
1159
1118
  } else if (index.metric_type == METRIC_L2) {
1160
- return new IVFPQScanner
1161
- <METRIC_L2, CMax<float, idx_t>, PQDecoder>
1162
- (index, store_pairs, 2);
1119
+ return new IVFPQScanner<METRIC_L2, CMax<float, idx_t>, PQDecoder>(
1120
+ index, store_pairs, 2);
1163
1121
  }
1164
1122
  return nullptr;
1165
1123
  }
1166
1124
 
1167
-
1168
1125
  } // anonymous namespace
1169
1126
 
1170
- InvertedListScanner *
1171
- IndexIVFPQ::get_InvertedListScanner (bool store_pairs) const
1172
- {
1173
-
1127
+ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
1128
+ bool store_pairs) const {
1174
1129
  if (pq.nbits == 8) {
1175
- return get_InvertedListScanner1<PQDecoder8> (*this, store_pairs);
1130
+ return get_InvertedListScanner1<PQDecoder8>(*this, store_pairs);
1176
1131
  } else if (pq.nbits == 16) {
1177
- return get_InvertedListScanner1<PQDecoder16> (*this, store_pairs);
1132
+ return get_InvertedListScanner1<PQDecoder16>(*this, store_pairs);
1178
1133
  } else {
1179
- return get_InvertedListScanner1<PQDecoderGeneric> (*this, store_pairs);
1134
+ return get_InvertedListScanner1<PQDecoderGeneric>(*this, store_pairs);
1180
1135
  }
1181
1136
  return nullptr;
1182
-
1183
1137
  }
1184
1138
 
1185
-
1186
-
1187
1139
  IndexIVFPQStats indexIVFPQ_stats;
1188
1140
 
1189
- void IndexIVFPQStats::reset () {
1190
- memset (this, 0, sizeof (*this));
1141
+ void IndexIVFPQStats::reset() {
1142
+ memset(this, 0, sizeof(*this));
1191
1143
  }
1192
1144
 
1193
-
1194
-
1195
- IndexIVFPQ::IndexIVFPQ ()
1196
- {
1145
+ IndexIVFPQ::IndexIVFPQ() {
1197
1146
  // initialize some runtime values
1198
1147
  use_precomputed_table = 0;
1199
1148
  scan_table_threshold = 0;
@@ -1202,43 +1151,40 @@ IndexIVFPQ::IndexIVFPQ ()
1202
1151
  polysemous_training = nullptr;
1203
1152
  }
1204
1153
 
1205
-
1206
1154
  struct CodeCmp {
1207
- const uint8_t *tab;
1155
+ const uint8_t* tab;
1208
1156
  size_t code_size;
1209
- bool operator () (int a, int b) const {
1210
- return cmp (a, b) > 0;
1157
+ bool operator()(int a, int b) const {
1158
+ return cmp(a, b) > 0;
1211
1159
  }
1212
- int cmp (int a, int b) const {
1213
- return memcmp (tab + a * code_size, tab + b * code_size,
1214
- code_size);
1160
+ int cmp(int a, int b) const {
1161
+ return memcmp(tab + a * code_size, tab + b * code_size, code_size);
1215
1162
  }
1216
1163
  };
1217
1164
 
1218
-
1219
- size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
1220
- {
1165
+ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
1221
1166
  size_t ngroup = 0;
1222
1167
  lims[0] = 0;
1223
1168
  for (size_t list_no = 0; list_no < nlist; list_no++) {
1224
- size_t n = invlists->list_size (list_no);
1225
- std::vector<int> ord (n);
1226
- for (int i = 0; i < n; i++) ord[i] = i;
1227
- InvertedLists::ScopedCodes codes (invlists, list_no);
1228
- CodeCmp cs = { codes.get(), code_size };
1229
- std::sort (ord.begin(), ord.end(), cs);
1230
-
1231
- InvertedLists::ScopedIds list_ids (invlists, list_no);
1232
- int prev = -1; // all elements from prev to i-1 are equal
1169
+ size_t n = invlists->list_size(list_no);
1170
+ std::vector<int> ord(n);
1171
+ for (int i = 0; i < n; i++)
1172
+ ord[i] = i;
1173
+ InvertedLists::ScopedCodes codes(invlists, list_no);
1174
+ CodeCmp cs = {codes.get(), code_size};
1175
+ std::sort(ord.begin(), ord.end(), cs);
1176
+
1177
+ InvertedLists::ScopedIds list_ids(invlists, list_no);
1178
+ int prev = -1; // all elements from prev to i-1 are equal
1233
1179
  for (int i = 0; i < n; i++) {
1234
- if (prev >= 0 && cs.cmp (ord [prev], ord [i]) == 0) {
1180
+ if (prev >= 0 && cs.cmp(ord[prev], ord[i]) == 0) {
1235
1181
  // same as previous => remember
1236
1182
  if (prev + 1 == i) { // start new group
1237
1183
  ngroup++;
1238
1184
  lims[ngroup] = lims[ngroup - 1];
1239
- dup_ids [lims [ngroup]++] = list_ids [ord [prev]];
1185
+ dup_ids[lims[ngroup]++] = list_ids[ord[prev]];
1240
1186
  }
1241
- dup_ids [lims [ngroup]++] = list_ids [ord [i]];
1187
+ dup_ids[lims[ngroup]++] = list_ids[ord[i]];
1242
1188
  } else { // not same as previous.
1243
1189
  prev = i;
1244
1190
  }
@@ -1247,9 +1193,4 @@ size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
1247
1193
  return ngroup;
1248
1194
  }
1249
1195
 
1250
-
1251
-
1252
-
1253
-
1254
-
1255
1196
  } // namespace faiss