faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
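
Among the files this release adds to the vendored FAISS sources are several new index and quantizer implementations (IndexNNDescent, IndexNSG, IndexResidual, AdditiveQuantizer, LocalSearchQuantizer, ResidualQuantizer) plus a NEON backend for simdlib. As a rough orientation only, the sketch below shows how one of the newly vendored graph indexes is typically driven through the upstream C++ API; the header path comes from the file list above, while the IndexNSGFlat constructor signature and the parameter values are assumptions based on upstream FAISS 1.7.x and are not guaranteed by this diff.

    #include <faiss/IndexNSG.h>  // newly vendored header (see file list above)

    #include <vector>

    int main() {
        int d = 64;                        // vector dimensionality
        faiss::IndexNSGFlat index(d, 32);  // 32 = graph degree R (assumed upstream signature)

        std::vector<float> xb(1000 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 97) * 0.01f;      // arbitrary demo data
        index.add(1000, xb.data());        // NSG builds its graph over the added vectors

        std::vector<float> distances(5);
        std::vector<faiss::Index::idx_t> labels(5);
        index.search(1, xb.data(), 5, distances.data(), labels.data());
        return 0;
    }
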
@@ -10,11 +10,11 @@
  #ifndef META_INDEXES_H
  #define META_INDEXES_H

- #include <vector>
- #include <unordered_map>
  #include <faiss/Index.h>
- #include <faiss/IndexShards.h>
  #include <faiss/IndexReplicas.h>
+ #include <faiss/IndexShards.h>
+ #include <unordered_map>
+ #include <vector>

  namespace faiss {

@@ -25,22 +25,25 @@ struct IndexIDMapTemplate : IndexT {
  using component_t = typename IndexT::component_t;
  using distance_t = typename IndexT::distance_t;

- IndexT * index; ///! the sub-index
- bool own_fields; ///! whether pointers are deleted in destructo
+ IndexT* index; ///! the sub-index
+ bool own_fields; ///! whether pointers are deleted in destructo
  std::vector<idx_t> id_map;

- explicit IndexIDMapTemplate (IndexT *index);
+ explicit IndexIDMapTemplate(IndexT* index);

  /// @param xids if non-null, ids to store for the vectors (size n)
- void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override;
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
+ override;

  /// this will fail. Use add_with_ids
  void add(idx_t n, const component_t* x) override;

  void search(
- idx_t n, const component_t* x, idx_t k,
- distance_t* distances,
- idx_t* labels) const override;
+ idx_t n,
+ const component_t* x,
+ idx_t k,
+ distance_t* distances,
+ idx_t* labels) const override;

  void train(idx_t n, const component_t* x) override;

@@ -49,17 +52,22 @@ struct IndexIDMapTemplate : IndexT {
  /// remove ids adapted to IndexFlat
  size_t remove_ids(const IDSelector& sel) override;

- void range_search (idx_t n, const component_t *x, distance_t radius,
- RangeSearchResult *result) const override;
-
- ~IndexIDMapTemplate () override;
- IndexIDMapTemplate () {own_fields=false; index=nullptr; }
+ void range_search(
+ idx_t n,
+ const component_t* x,
+ distance_t radius,
+ RangeSearchResult* result) const override;
+
+ ~IndexIDMapTemplate() override;
+ IndexIDMapTemplate() {
+ own_fields = false;
+ index = nullptr;
+ }
  };

  using IndexIDMap = IndexIDMapTemplate<Index>;
  using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;

-
  /** same as IndexIDMap but also provides an efficient reconstruction
  * implementation via a 2-way index */
  template <typename IndexT>
@@ -70,47 +78,47 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {

  std::unordered_map<idx_t, idx_t> rev_map;

- explicit IndexIDMap2Template (IndexT *index);
+ explicit IndexIDMap2Template(IndexT* index);

  /// make the rev_map from scratch
- void construct_rev_map ();
+ void construct_rev_map();

- void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override;
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
+ override;

  size_t remove_ids(const IDSelector& sel) override;

- void reconstruct (idx_t key, component_t * recons) const override;
+ void reconstruct(idx_t key, component_t* recons) const override;

  ~IndexIDMap2Template() override {}
- IndexIDMap2Template () {}
+ IndexIDMap2Template() {}
  };

  using IndexIDMap2 = IndexIDMap2Template<Index>;
  using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;

-
  /** splits input vectors in segments and assigns each segment to a sub-index
  * used to distribute a MultiIndexQuantizer
  */
- struct IndexSplitVectors: Index {
+ struct IndexSplitVectors : Index {
  bool own_fields;
  bool threaded;
  std::vector<Index*> sub_indexes;
- idx_t sum_d; /// sum of dimensions seen so far
+ idx_t sum_d; /// sum of dimensions seen so far

- explicit IndexSplitVectors (idx_t d, bool threaded = false);
+ explicit IndexSplitVectors(idx_t d, bool threaded = false);

- void add_sub_index (Index *);
- void sync_with_sub_indexes ();
+ void add_sub_index(Index*);
+ void sync_with_sub_indexes();

  void add(idx_t n, const float* x) override;

  void search(
- idx_t n,
- const float* x,
- idx_t k,
- float* distances,
- idx_t* labels) const override;
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const override;

  void train(idx_t n, const float* x) override;

@@ -119,8 +127,6 @@ struct IndexSplitVectors: Index {
  ~IndexSplitVectors() override;
  };

-
  } // namespace faiss

-
  #endif
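
The MetaIndexes.h hunks above are essentially a formatting pass; the declarations themselves (IndexIDMap, IndexIDMap2, IndexSplitVectors) keep their previous meaning. As a reminder of how the IndexIDMap wrapper declared here is normally used, a minimal sketch follows; the names n, xb, xids, nq, xq and the result buffers are hypothetical placeholders, not values taken from this diff.

    #include <faiss/IndexFlat.h>
    #include <faiss/MetaIndexes.h>

    // wrap a flat index so vectors can be stored under caller-chosen ids
    faiss::IndexFlatL2 base(64);
    faiss::IndexIDMap idmap(&base);
    // idmap.add_with_ids(n, xb, xids);            // xids: one caller id per vector
    // idmap.search(nq, xq, k, distances, labels); // labels come back as those ids
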
@@ -18,12 +18,12 @@ namespace faiss {
  /// (brute-force) indices supporting additional metric types for vector
  /// comparison.
  enum MetricType {
- METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
- METRIC_L2 = 1, ///< squared L2 search
- METRIC_L1, ///< L1 (aka cityblock)
- METRIC_Linf, ///< infinity distance
- METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
- /// metric_arg
+ METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
+ METRIC_L2 = 1, ///< squared L2 search
+ METRIC_L1, ///< L1 (aka cityblock)
+ METRIC_Linf, ///< infinity distance
+ METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
+ /// metric_arg

  /// some additional metrics defined in scipy.spatial.distance
  METRIC_Canberra = 20,
@@ -31,6 +31,6 @@ enum MetricType {
  METRIC_JensenShannon,
  };

- }
+ } // namespace faiss

  #endif
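
The MetricType.h change above is whitespace and comment formatting only; the enum values are what callers pass when they want an index that uses a non-default metric. A minimal sketch of that, following standard upstream FAISS usage rather than anything introduced by this diff:

    #include <faiss/IndexFlat.h>

    // brute-force index comparing 64-dimensional vectors with L1 (cityblock) distance
    faiss::IndexFlat index(64, faiss::METRIC_L1);
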
@@ -10,20 +10,19 @@
  #include <faiss/VectorTransform.h>

  #include <cinttypes>
- #include <cstdio>
  #include <cmath>
+ #include <cstdio>
  #include <cstring>
  #include <memory>

+ #include <faiss/IndexPQ.h>
+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/distances.h>
  #include <faiss/utils/random.h>
  #include <faiss/utils/utils.h>
- #include <faiss/impl/FaissAssert.h>
- #include <faiss/IndexPQ.h>

  using namespace faiss;

-
  extern "C" {

  // this is to keep the clang syntax checker happy
@@ -31,134 +30,183 @@ extern "C" {
31
30
  #define FINTEGER int
32
31
  #endif
33
32
 
34
-
35
33
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
36
34
 
37
- int sgemm_ (
38
- const char *transa, const char *transb, FINTEGER *m, FINTEGER *
39
- n, FINTEGER *k, const float *alpha, const float *a,
40
- FINTEGER *lda, const float *b,
41
- FINTEGER *ldb, float *beta,
42
- float *c, FINTEGER *ldc);
43
-
44
- int dgemm_ (
45
- const char *transa, const char *transb, FINTEGER *m, FINTEGER *
46
- n, FINTEGER *k, const double *alpha, const double *a,
47
- FINTEGER *lda, const double *b,
48
- FINTEGER *ldb, double *beta,
49
- double *c, FINTEGER *ldc);
50
-
51
- int ssyrk_ (
52
- const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k,
53
- float *alpha, float *a, FINTEGER *lda,
54
- float *beta, float *c, FINTEGER *ldc);
35
+ int sgemm_(
36
+ const char* transa,
37
+ const char* transb,
38
+ FINTEGER* m,
39
+ FINTEGER* n,
40
+ FINTEGER* k,
41
+ const float* alpha,
42
+ const float* a,
43
+ FINTEGER* lda,
44
+ const float* b,
45
+ FINTEGER* ldb,
46
+ float* beta,
47
+ float* c,
48
+ FINTEGER* ldc);
49
+
50
+ int dgemm_(
51
+ const char* transa,
52
+ const char* transb,
53
+ FINTEGER* m,
54
+ FINTEGER* n,
55
+ FINTEGER* k,
56
+ const double* alpha,
57
+ const double* a,
58
+ FINTEGER* lda,
59
+ const double* b,
60
+ FINTEGER* ldb,
61
+ double* beta,
62
+ double* c,
63
+ FINTEGER* ldc);
64
+
65
+ int ssyrk_(
66
+ const char* uplo,
67
+ const char* trans,
68
+ FINTEGER* n,
69
+ FINTEGER* k,
70
+ float* alpha,
71
+ float* a,
72
+ FINTEGER* lda,
73
+ float* beta,
74
+ float* c,
75
+ FINTEGER* ldc);
55
76
 
56
77
  /* Lapack functions from http://www.netlib.org/clapack/old/single/ */
57
78
 
58
- int ssyev_ (
59
- const char *jobz, const char *uplo, FINTEGER *n, float *a,
60
- FINTEGER *lda, float *w, float *work, FINTEGER *lwork,
61
- FINTEGER *info);
62
-
63
- int dsyev_ (
64
- const char *jobz, const char *uplo, FINTEGER *n, double *a,
65
- FINTEGER *lda, double *w, double *work, FINTEGER *lwork,
66
- FINTEGER *info);
79
+ int ssyev_(
80
+ const char* jobz,
81
+ const char* uplo,
82
+ FINTEGER* n,
83
+ float* a,
84
+ FINTEGER* lda,
85
+ float* w,
86
+ float* work,
87
+ FINTEGER* lwork,
88
+ FINTEGER* info);
89
+
90
+ int dsyev_(
91
+ const char* jobz,
92
+ const char* uplo,
93
+ FINTEGER* n,
94
+ double* a,
95
+ FINTEGER* lda,
96
+ double* w,
97
+ double* work,
98
+ FINTEGER* lwork,
99
+ FINTEGER* info);
67
100
 
68
101
  int sgesvd_(
69
- const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
70
- float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt,
71
- FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info);
72
-
102
+ const char* jobu,
103
+ const char* jobvt,
104
+ FINTEGER* m,
105
+ FINTEGER* n,
106
+ float* a,
107
+ FINTEGER* lda,
108
+ float* s,
109
+ float* u,
110
+ FINTEGER* ldu,
111
+ float* vt,
112
+ FINTEGER* ldvt,
113
+ float* work,
114
+ FINTEGER* lwork,
115
+ FINTEGER* info);
73
116
 
74
117
  int dgesvd_(
75
- const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
76
- double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt,
77
- FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info);
78
-
118
+ const char* jobu,
119
+ const char* jobvt,
120
+ FINTEGER* m,
121
+ FINTEGER* n,
122
+ double* a,
123
+ FINTEGER* lda,
124
+ double* s,
125
+ double* u,
126
+ FINTEGER* ldu,
127
+ double* vt,
128
+ FINTEGER* ldvt,
129
+ double* work,
130
+ FINTEGER* lwork,
131
+ FINTEGER* info);
79
132
  }
80
133
 
81
134
  /*********************************************
82
135
  * VectorTransform
83
136
  *********************************************/
84
137
 
85
-
86
-
87
- float * VectorTransform::apply (Index::idx_t n, const float * x) const
88
- {
89
- float * xt = new float[n * d_out];
90
- apply_noalloc (n, x, xt);
138
+ float* VectorTransform::apply(Index::idx_t n, const float* x) const {
139
+ float* xt = new float[n * d_out];
140
+ apply_noalloc(n, x, xt);
91
141
  return xt;
92
142
  }
93
143
 
94
-
95
- void VectorTransform::train (idx_t, const float *) {
144
+ void VectorTransform::train(idx_t, const float*) {
96
145
  // does nothing by default
97
146
  }
98
147
 
99
-
100
- void VectorTransform::reverse_transform (
101
- idx_t , const float *,
102
- float *) const
103
- {
104
- FAISS_THROW_MSG ("reverse transform not implemented");
148
+ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
149
+ FAISS_THROW_MSG("reverse transform not implemented");
105
150
  }
106
151
 
107
-
108
-
109
-
110
152
  /*********************************************
111
153
  * LinearTransform
112
154
  *********************************************/
113
155
 
114
156
  /// both d_in > d_out and d_out < d_in are supported
115
- LinearTransform::LinearTransform (int d_in, int d_out,
116
- bool have_bias):
117
- VectorTransform (d_in, d_out), have_bias (have_bias),
118
- is_orthonormal (false), verbose (false)
119
- {
157
+ LinearTransform::LinearTransform(int d_in, int d_out, bool have_bias)
158
+ : VectorTransform(d_in, d_out),
159
+ have_bias(have_bias),
160
+ is_orthonormal(false),
161
+ verbose(false) {
120
162
  is_trained = false; // will be trained when A and b are initialized
121
163
  }
122
164
 
123
- void LinearTransform::apply_noalloc (Index::idx_t n, const float * x,
124
- float * xt) const
125
- {
165
+ void LinearTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
166
+ const {
126
167
  FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
127
168
 
128
169
  float c_factor;
129
170
  if (have_bias) {
130
- FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized");
131
- float * xi = xt;
171
+ FAISS_THROW_IF_NOT_MSG(b.size() == d_out, "Bias not initialized");
172
+ float* xi = xt;
132
173
  for (int i = 0; i < n; i++)
133
- for(int j = 0; j < d_out; j++)
174
+ for (int j = 0; j < d_out; j++)
134
175
  *xi++ = b[j];
135
176
  c_factor = 1.0;
136
177
  } else {
137
178
  c_factor = 0.0;
138
179
  }
139
180
 
140
- FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in,
141
- "Transformation matrix not initialized");
181
+ FAISS_THROW_IF_NOT_MSG(
182
+ A.size() == d_out * d_in, "Transformation matrix not initialized");
142
183
 
143
184
  float one = 1;
144
185
  FINTEGER nbiti = d_out, ni = n, di = d_in;
145
- sgemm_ ("Transposed", "Not transposed",
146
- &nbiti, &ni, &di,
147
- &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti);
148
-
186
+ sgemm_("Transposed",
187
+ "Not transposed",
188
+ &nbiti,
189
+ &ni,
190
+ &di,
191
+ &one,
192
+ A.data(),
193
+ &di,
194
+ x,
195
+ &di,
196
+ &c_factor,
197
+ xt,
198
+ &nbiti);
149
199
  }
150
200
 
151
-
152
- void LinearTransform::transform_transpose (idx_t n, const float * y,
153
- float *x) const
154
- {
201
+ void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
202
+ const {
155
203
  if (have_bias) { // allocate buffer to store bias-corrected data
156
- float *y_new = new float [n * d_out];
157
- const float *yr = y;
158
- float *yw = y_new;
204
+ float* y_new = new float[n * d_out];
205
+ const float* yr = y;
206
+ float* yw = y_new;
159
207
  for (idx_t i = 0; i < n; i++) {
160
208
  for (int j = 0; j < d_out; j++) {
161
- *yw++ = *yr++ - b [j];
209
+ *yw++ = *yr++ - b[j];
162
210
  }
163
211
  }
164
212
  y = y_new;
@@ -167,15 +215,26 @@ void LinearTransform::transform_transpose (idx_t n, const float * y,
167
215
  {
168
216
  FINTEGER dii = d_in, doi = d_out, ni = n;
169
217
  float one = 1.0, zero = 0.0;
170
- sgemm_ ("Not", "Not", &dii, &ni, &doi,
171
- &one, A.data (), &dii, y, &doi, &zero, x, &dii);
218
+ sgemm_("Not",
219
+ "Not",
220
+ &dii,
221
+ &ni,
222
+ &doi,
223
+ &one,
224
+ A.data(),
225
+ &dii,
226
+ y,
227
+ &doi,
228
+ &zero,
229
+ x,
230
+ &dii);
172
231
  }
173
232
 
174
- if (have_bias) delete [] y;
233
+ if (have_bias)
234
+ delete[] y;
175
235
  }
176
236
 
177
- void LinearTransform::set_is_orthonormal ()
178
- {
237
+ void LinearTransform::set_is_orthonormal() {
179
238
  if (d_out > d_in) {
180
239
  // not clear what we should do in this case
181
240
  is_orthonormal = false;
@@ -193,44 +252,53 @@ void LinearTransform::set_is_orthonormal ()
193
252
  FINTEGER dii = d_in, doi = d_out;
194
253
  float one = 1.0, zero = 0.0;
195
254
 
196
- sgemm_ ("Transposed", "Not", &doi, &doi, &dii,
197
- &one, A.data (), &dii,
198
- A.data(), &dii,
199
- &zero, ATA.data(), &doi);
255
+ sgemm_("Transposed",
256
+ "Not",
257
+ &doi,
258
+ &doi,
259
+ &dii,
260
+ &one,
261
+ A.data(),
262
+ &dii,
263
+ A.data(),
264
+ &dii,
265
+ &zero,
266
+ ATA.data(),
267
+ &doi);
200
268
 
201
269
  is_orthonormal = true;
202
270
  for (long i = 0; i < d_out; i++) {
203
271
  for (long j = 0; j < d_out; j++) {
204
272
  float v = ATA[i + j * d_out];
205
- if (i == j) v-= 1;
273
+ if (i == j)
274
+ v -= 1;
206
275
  if (fabs(v) > eps) {
207
276
  is_orthonormal = false;
208
277
  }
209
278
  }
210
279
  }
211
280
  }
212
-
213
281
  }
214
282
 
215
-
216
- void LinearTransform::reverse_transform (idx_t n, const float * xt,
217
- float *x) const
218
- {
283
+ void LinearTransform::reverse_transform(idx_t n, const float* xt, float* x)
284
+ const {
219
285
  if (is_orthonormal) {
220
- transform_transpose (n, xt, x);
286
+ transform_transpose(n, xt, x);
221
287
  } else {
222
- FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices");
288
+ FAISS_THROW_MSG(
289
+ "reverse transform not implemented for non-orthonormal matrices");
223
290
  }
224
291
  }
225
292
 
226
-
227
- void LinearTransform::print_if_verbose (
228
- const char*name, const std::vector<double> &mat,
229
- int n, int d) const
230
- {
231
- if (!verbose) return;
293
+ void LinearTransform::print_if_verbose(
294
+ const char* name,
295
+ const std::vector<double>& mat,
296
+ int n,
297
+ int d) const {
298
+ if (!verbose)
299
+ return;
232
300
  printf("matrix %s: %d*%d [\n", name, n, d);
233
- FAISS_THROW_IF_NOT (mat.size() >= n * d);
301
+ FAISS_THROW_IF_NOT(mat.size() >= n * d);
234
302
  for (int i = 0; i < n; i++) {
235
303
  for (int j = 0; j < d; j++) {
236
304
  printf("%10.5g ", mat[i * d + j]);
@@ -244,24 +312,22 @@ void LinearTransform::print_if_verbose (
244
312
  * RandomRotationMatrix
245
313
  *********************************************/
246
314
 
247
- void RandomRotationMatrix::init (int seed)
248
- {
249
-
250
- if(d_out <= d_in) {
251
- A.resize (d_out * d_in);
252
- float *q = A.data();
315
+ void RandomRotationMatrix::init(int seed) {
316
+ if (d_out <= d_in) {
317
+ A.resize(d_out * d_in);
318
+ float* q = A.data();
253
319
  float_randn(q, d_out * d_in, seed);
254
320
  matrix_qr(d_in, d_out, q);
255
321
  } else {
256
322
  // use tight-frame transformation
257
- A.resize (d_out * d_out);
258
- float *q = A.data();
323
+ A.resize(d_out * d_out);
324
+ float* q = A.data();
259
325
  float_randn(q, d_out * d_out, seed);
260
326
  matrix_qr(d_out, d_out, q);
261
327
  // remove columns
262
328
  int i, j;
263
329
  for (i = 0; i < d_out; i++) {
264
- for(j = 0; j < d_in; j++) {
330
+ for (j = 0; j < d_in; j++) {
265
331
  q[i * d_in + j] = q[i * d_out + j];
266
332
  }
267
333
  }
@@ -271,247 +337,280 @@ void RandomRotationMatrix::init (int seed)
271
337
  is_trained = true;
272
338
  }
273
339
 
274
- void RandomRotationMatrix::train (Index::idx_t /*n*/, const float * /*x*/)
275
- {
340
+ void RandomRotationMatrix::train(Index::idx_t /*n*/, const float* /*x*/) {
276
341
  // initialize with some arbitrary seed
277
- init (12345);
342
+ init(12345);
278
343
  }
279
344
 
280
-
281
345
  /*********************************************
282
346
  * PCAMatrix
283
347
  *********************************************/
284
348
 
285
- PCAMatrix::PCAMatrix (int d_in, int d_out,
286
- float eigen_power, bool random_rotation):
287
- LinearTransform(d_in, d_out, true),
288
- eigen_power(eigen_power), random_rotation(random_rotation)
289
- {
349
+ PCAMatrix::PCAMatrix(
350
+ int d_in,
351
+ int d_out,
352
+ float eigen_power,
353
+ bool random_rotation)
354
+ : LinearTransform(d_in, d_out, true),
355
+ eigen_power(eigen_power),
356
+ random_rotation(random_rotation) {
290
357
  is_trained = false;
291
358
  max_points_per_d = 1000;
292
359
  balanced_bins = 0;
293
360
  }
294
361
 
295
-
296
362
  namespace {
297
363
 
298
364
  /// Compute the eigenvalue decomposition of symmetric matrix cov,
299
365
  /// dimensions d_in-by-d_in. Output eigenvectors in cov.
300
366
 
301
- void eig(size_t d_in, double *cov, double *eigenvalues, int verbose)
302
- {
367
+ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
303
368
  { // compute eigenvalues and vectors
304
369
  FINTEGER info = 0, lwork = -1, di = d_in;
305
370
  double workq;
306
371
 
307
- dsyev_ ("Vectors as well", "Upper",
308
- &di, cov, &di, eigenvalues, &workq, &lwork, &info);
372
+ dsyev_("Vectors as well",
373
+ "Upper",
374
+ &di,
375
+ cov,
376
+ &di,
377
+ eigenvalues,
378
+ &workq,
379
+ &lwork,
380
+ &info);
309
381
  lwork = FINTEGER(workq);
310
- double *work = new double[lwork];
382
+ double* work = new double[lwork];
311
383
 
312
- dsyev_ ("Vectors as well", "Upper",
313
- &di, cov, &di, eigenvalues, work, &lwork, &info);
384
+ dsyev_("Vectors as well",
385
+ "Upper",
386
+ &di,
387
+ cov,
388
+ &di,
389
+ eigenvalues,
390
+ work,
391
+ &lwork,
392
+ &info);
314
393
 
315
- delete [] work;
394
+ delete[] work;
316
395
 
317
396
  if (info != 0) {
318
- fprintf (stderr, "WARN ssyev info returns %d, "
319
- "a very bad PCA matrix is learnt\n",
320
- int(info));
397
+ fprintf(stderr,
398
+ "WARN ssyev info returns %d, "
399
+ "a very bad PCA matrix is learnt\n",
400
+ int(info));
321
401
  // do not throw exception, as the matrix could still be useful
322
402
  }
323
403
 
324
-
325
- if(verbose && d_in <= 10) {
404
+ if (verbose && d_in <= 10) {
326
405
  printf("info=%ld new eigvals=[", long(info));
327
- for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]);
406
+ for (int j = 0; j < d_in; j++)
407
+ printf("%g ", eigenvalues[j]);
328
408
  printf("]\n");
329
409
 
330
- double *ci = cov;
410
+ double* ci = cov;
331
411
  printf("eigenvecs=\n");
332
- for(int i = 0; i < d_in; i++) {
333
- for(int j = 0; j < d_in; j++)
412
+ for (int i = 0; i < d_in; i++) {
413
+ for (int j = 0; j < d_in; j++)
334
414
  printf("%10.4g ", *ci++);
335
415
  printf("\n");
336
416
  }
337
417
  }
338
-
339
418
  }
340
419
 
341
420
  // revert order of eigenvectors & values
342
421
 
343
- for(int i = 0; i < d_in / 2; i++) {
344
-
422
+ for (int i = 0; i < d_in / 2; i++) {
345
423
  std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
346
- double *v1 = cov + i * d_in;
347
- double *v2 = cov + (d_in - 1 - i) * d_in;
348
- for(int j = 0; j < d_in; j++)
424
+ double* v1 = cov + i * d_in;
425
+ double* v2 = cov + (d_in - 1 - i) * d_in;
426
+ for (int j = 0; j < d_in; j++)
349
427
  std::swap(v1[j], v2[j]);
350
428
  }
351
-
352
429
  }
353
430
 
431
+ } // namespace
354
432
 
355
- }
356
-
357
- void PCAMatrix::train (Index::idx_t n, const float *x)
358
- {
359
- const float * x_in = x;
433
+ void PCAMatrix::train(Index::idx_t n, const float* x) {
434
+ const float* x_in = x;
360
435
 
361
- x = fvecs_maybe_subsample (d_in, (size_t*)&n,
362
- max_points_per_d * d_in, x, verbose);
436
+ x = fvecs_maybe_subsample(
437
+ d_in, (size_t*)&n, max_points_per_d * d_in, x, verbose);
363
438
 
364
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
439
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
365
440
 
366
441
  // compute mean
367
- mean.clear(); mean.resize(d_in, 0.0);
442
+ mean.clear();
443
+ mean.resize(d_in, 0.0);
368
444
  if (have_bias) { // we may want to skip the bias
369
- const float *xi = x;
445
+ const float* xi = x;
370
446
  for (int i = 0; i < n; i++) {
371
- for(int j = 0; j < d_in; j++)
447
+ for (int j = 0; j < d_in; j++)
372
448
  mean[j] += *xi++;
373
449
  }
374
- for(int j = 0; j < d_in; j++)
450
+ for (int j = 0; j < d_in; j++)
375
451
  mean[j] /= n;
376
452
  }
377
- if(verbose) {
453
+ if (verbose) {
378
454
  printf("mean=[");
379
- for(int j = 0; j < d_in; j++) printf("%g ", mean[j]);
455
+ for (int j = 0; j < d_in; j++)
456
+ printf("%g ", mean[j]);
380
457
  printf("]\n");
381
458
  }
382
459
 
383
- if(n >= d_in) {
460
+ if (n >= d_in) {
384
461
  // compute covariance matrix, store it in PCA matrix
385
462
  PCAMat.resize(d_in * d_in);
386
- float * cov = PCAMat.data();
463
+ float* cov = PCAMat.data();
387
464
  { // initialize with mean * mean^T term
388
- float *ci = cov;
389
- for(int i = 0; i < d_in; i++) {
390
- for(int j = 0; j < d_in; j++)
391
- *ci++ = - n * mean[i] * mean[j];
465
+ float* ci = cov;
466
+ for (int i = 0; i < d_in; i++) {
467
+ for (int j = 0; j < d_in; j++)
468
+ *ci++ = -n * mean[i] * mean[j];
392
469
  }
393
470
  }
394
471
  {
395
472
  FINTEGER di = d_in, ni = n;
396
473
  float one = 1.0;
397
- ssyrk_ ("Up", "Non transposed",
398
- &di, &ni, &one, (float*)x, &di, &one, cov, &di);
399
-
474
+ ssyrk_("Up",
475
+ "Non transposed",
476
+ &di,
477
+ &ni,
478
+ &one,
479
+ (float*)x,
480
+ &di,
481
+ &one,
482
+ cov,
483
+ &di);
400
484
  }
401
- if(verbose && d_in <= 10) {
402
- float *ci = cov;
485
+ if (verbose && d_in <= 10) {
486
+ float* ci = cov;
403
487
  printf("cov=\n");
404
- for(int i = 0; i < d_in; i++) {
405
- for(int j = 0; j < d_in; j++)
488
+ for (int i = 0; i < d_in; i++) {
489
+ for (int j = 0; j < d_in; j++)
406
490
  printf("%10g ", *ci++);
407
491
  printf("\n");
408
492
  }
409
493
  }
410
494
 
411
- std::vector<double> covd (d_in * d_in);
412
- for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i];
495
+ std::vector<double> covd(d_in * d_in);
496
+ for (size_t i = 0; i < d_in * d_in; i++)
497
+ covd[i] = cov[i];
413
498
 
414
- std::vector<double> eigenvaluesd (d_in);
499
+ std::vector<double> eigenvaluesd(d_in);
415
500
 
416
- eig (d_in, covd.data (), eigenvaluesd.data (), verbose);
501
+ eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
417
502
 
418
- for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i];
419
- eigenvalues.resize (d_in);
503
+ for (size_t i = 0; i < d_in * d_in; i++)
504
+ PCAMat[i] = covd[i];
505
+ eigenvalues.resize(d_in);
420
506
 
421
507
  for (size_t i = 0; i < d_in; i++)
422
- eigenvalues [i] = eigenvaluesd [i];
423
-
508
+ eigenvalues[i] = eigenvaluesd[i];
424
509
 
425
510
  } else {
426
-
427
- std::vector<float> xc (n * d_in);
511
+ std::vector<float> xc(n * d_in);
428
512
 
429
513
  for (size_t i = 0; i < n; i++)
430
- for(size_t j = 0; j < d_in; j++)
431
- xc [i * d_in + j] = x [i * d_in + j] - mean[j];
514
+ for (size_t j = 0; j < d_in; j++)
515
+ xc[i * d_in + j] = x[i * d_in + j] - mean[j];
432
516
 
433
517
  // compute Gram matrix
434
- std::vector<float> gram (n * n);
518
+ std::vector<float> gram(n * n);
435
519
  {
436
520
  FINTEGER di = d_in, ni = n;
437
521
  float one = 1.0, zero = 0.0;
438
- ssyrk_ ("Up", "Transposed",
439
- &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni);
522
+ ssyrk_("Up",
523
+ "Transposed",
524
+ &ni,
525
+ &di,
526
+ &one,
527
+ xc.data(),
528
+ &di,
529
+ &zero,
530
+ gram.data(),
531
+ &ni);
440
532
  }
441
533
 
442
- if(verbose && d_in <= 10) {
443
- float *ci = gram.data();
534
+ if (verbose && d_in <= 10) {
535
+ float* ci = gram.data();
444
536
  printf("gram=\n");
445
- for(int i = 0; i < n; i++) {
446
- for(int j = 0; j < n; j++)
537
+ for (int i = 0; i < n; i++) {
538
+ for (int j = 0; j < n; j++)
447
539
  printf("%10g ", *ci++);
448
540
  printf("\n");
449
541
  }
450
542
  }
451
543
 
452
- std::vector<double> gramd (n * n);
544
+ std::vector<double> gramd(n * n);
453
545
  for (size_t i = 0; i < n * n; i++)
454
- gramd [i] = gram [i];
546
+ gramd[i] = gram[i];
455
547
 
456
- std::vector<double> eigenvaluesd (n);
548
+ std::vector<double> eigenvaluesd(n);
457
549
 
458
550
  // eig will fill in only the n first eigenvals
459
551
 
460
- eig (n, gramd.data (), eigenvaluesd.data (), verbose);
552
+ eig(n, gramd.data(), eigenvaluesd.data(), verbose);
461
553
 
462
554
  PCAMat.resize(d_in * n);
463
555
 
464
556
  for (size_t i = 0; i < n * n; i++)
465
- gram [i] = gramd [i];
557
+ gram[i] = gramd[i];
466
558
 
467
- eigenvalues.resize (d_in);
559
+ eigenvalues.resize(d_in);
468
560
  // fill in only the n first ones
469
561
  for (size_t i = 0; i < n; i++)
470
- eigenvalues [i] = eigenvaluesd [i];
562
+ eigenvalues[i] = eigenvaluesd[i];
471
563
 
472
564
  { // compute PCAMat = x' * v
473
565
  FINTEGER di = d_in, ni = n;
474
566
  float one = 1.0;
475
567
 
476
- sgemm_ ("Non", "Non Trans",
477
- &di, &ni, &ni,
478
- &one, xc.data(), &di, gram.data(), &ni,
479
- &one, PCAMat.data(), &di);
568
+ sgemm_("Non",
569
+ "Non Trans",
570
+ &di,
571
+ &ni,
572
+ &ni,
573
+ &one,
574
+ xc.data(),
575
+ &di,
576
+ gram.data(),
577
+ &ni,
578
+ &one,
579
+ PCAMat.data(),
580
+ &di);
480
581
  }
481
582
 
482
- if(verbose && d_in <= 10) {
483
- float *ci = PCAMat.data();
583
+ if (verbose && d_in <= 10) {
584
+ float* ci = PCAMat.data();
484
585
  printf("PCAMat=\n");
485
- for(int i = 0; i < n; i++) {
486
- for(int j = 0; j < d_in; j++)
586
+ for (int i = 0; i < n; i++) {
587
+ for (int j = 0; j < d_in; j++)
487
588
  printf("%10g ", *ci++);
488
589
  printf("\n");
489
590
  }
490
591
  }
491
- fvec_renorm_L2 (d_in, n, PCAMat.data());
492
-
592
+ fvec_renorm_L2(d_in, n, PCAMat.data());
493
593
  }
494
594
 
495
595
  prepare_Ab();
496
596
  is_trained = true;
497
597
  }
498
598
 
499
- void PCAMatrix::copy_from (const PCAMatrix & other)
500
- {
501
- FAISS_THROW_IF_NOT (other.is_trained);
599
+ void PCAMatrix::copy_from(const PCAMatrix& other) {
600
+ FAISS_THROW_IF_NOT(other.is_trained);
502
601
  mean = other.mean;
503
602
  eigenvalues = other.eigenvalues;
504
603
  PCAMat = other.PCAMat;
505
- prepare_Ab ();
604
+ prepare_Ab();
506
605
  is_trained = true;
507
606
  }
508
607
 
509
- void PCAMatrix::prepare_Ab ()
510
- {
511
- FAISS_THROW_IF_NOT_FMT (
608
+ void PCAMatrix::prepare_Ab() {
609
+ FAISS_THROW_IF_NOT_FMT(
512
610
  d_out * d_in <= PCAMat.size(),
513
611
  "PCA matrix cannot output %d dimensions from %d ",
514
- d_out, d_in);
612
+ d_out,
613
+ d_in);
515
614
 
516
615
  if (!random_rotation) {
517
616
  A = PCAMat;
@@ -519,23 +618,23 @@ void PCAMatrix::prepare_Ab ()
519
618
 
520
619
  // first scale the components
521
620
  if (eigen_power != 0) {
522
- float *ai = A.data();
621
+ float* ai = A.data();
523
622
  for (int i = 0; i < d_out; i++) {
524
623
  float factor = pow(eigenvalues[i], eigen_power);
525
- for(int j = 0; j < d_in; j++)
624
+ for (int j = 0; j < d_in; j++)
526
625
  *ai++ *= factor;
527
626
  }
528
627
  }
529
628
 
530
629
  if (balanced_bins != 0) {
531
- FAISS_THROW_IF_NOT (d_out % balanced_bins == 0);
630
+ FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
532
631
  int dsub = d_out / balanced_bins;
533
- std::vector <float> Ain;
632
+ std::vector<float> Ain;
534
633
  std::swap(A, Ain);
535
634
  A.resize(d_out * d_in);
536
635
 
537
- std::vector <float> accu(balanced_bins);
538
- std::vector <int> counter(balanced_bins);
636
+ std::vector<float> accu(balanced_bins);
637
+ std::vector<int> counter(balanced_bins);
539
638
 
540
639
  // greedy assignment
541
640
  for (int i = 0; i < d_out; i++) {
@@ -550,9 +649,8 @@ void PCAMatrix::prepare_Ab ()
550
649
  }
551
650
  int row_dst = best_j * dsub + counter[best_j];
552
651
  accu[best_j] += eigenvalues[i];
553
- counter[best_j] ++;
554
- memcpy (&A[row_dst * d_in], &Ain[i * d_in],
555
- d_in * sizeof (A[0]));
652
+ counter[best_j]++;
653
+ memcpy(&A[row_dst * d_in], &Ain[i * d_in], d_in * sizeof(A[0]));
556
654
  }
557
655
 
558
656
  if (verbose) {
@@ -563,11 +661,11 @@ void PCAMatrix::prepare_Ab ()
563
661
  }
564
662
  }
565
663
 
566
-
567
664
  } else {
568
- FAISS_THROW_IF_NOT_MSG (balanced_bins == 0,
569
- "both balancing bins and applying a random rotation "
570
- "does not make sense");
665
+ FAISS_THROW_IF_NOT_MSG(
666
+ balanced_bins == 0,
667
+ "both balancing bins and applying a random rotation "
668
+ "does not make sense");
571
669
  RandomRotationMatrix rr(d_out, d_out);
572
670
 
573
671
  rr.init(5);
@@ -576,8 +674,8 @@ void PCAMatrix::prepare_Ab ()
576
674
  if (eigen_power != 0) {
577
675
  for (int i = 0; i < d_out; i++) {
578
676
  float factor = pow(eigenvalues[i], eigen_power);
579
- for(int j = 0; j < d_out; j++)
580
- rr.A[j * d_out + i] *= factor;
677
+ for (int j = 0; j < d_out; j++)
678
+ rr.A[j * d_out + i] *= factor;
581
679
  }
582
680
  }
583
681
 
@@ -586,15 +684,24 @@ void PCAMatrix::prepare_Ab ()
586
684
  FINTEGER dii = d_in, doo = d_out;
587
685
  float one = 1.0, zero = 0.0;
588
686
 
589
- sgemm_ ("Not", "Not", &dii, &doo, &doo,
590
- &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero,
591
- A.data(), &dii);
592
-
687
+ sgemm_("Not",
688
+ "Not",
689
+ &dii,
690
+ &doo,
691
+ &doo,
692
+ &one,
693
+ PCAMat.data(),
694
+ &dii,
695
+ rr.A.data(),
696
+ &doo,
697
+ &zero,
698
+ A.data(),
699
+ &dii);
593
700
  }
594
-
595
701
  }
596
702
 
597
- b.clear(); b.resize(d_out);
703
+ b.clear();
704
+ b.resize(d_out);
598
705
 
599
706
  for (int i = 0; i < d_out; i++) {
600
707
  float accu = 0;
@@ -604,57 +711,61 @@ void PCAMatrix::prepare_Ab ()
604
711
  }
605
712
 
606
713
  is_orthonormal = eigen_power == 0;
607
-
608
714
  }
609
715
 
610
716
  /*********************************************
611
717
  * ITQMatrix
612
718
  *********************************************/
613
719
 
614
- ITQMatrix::ITQMatrix (int d):
615
- LinearTransform(d, d, false),
616
- max_iter (50),
617
- seed (123)
618
- {
619
- }
620
-
720
+ ITQMatrix::ITQMatrix(int d)
721
+ : LinearTransform(d, d, false), max_iter(50), seed(123) {}
621
722
 
622
723
  /** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
623
- void ITQMatrix::train (Index::idx_t n, const float* xf)
624
- {
724
+ void ITQMatrix::train(Index::idx_t n, const float* xf) {
625
725
  size_t d = d_in;
626
- std::vector<double> rotation (d * d);
726
+ std::vector<double> rotation(d * d);
627
727
 
628
728
  if (init_rotation.size() == d * d) {
629
- memcpy (rotation.data(), init_rotation.data(),
630
- d * d * sizeof(rotation[0]));
729
+ memcpy(rotation.data(),
730
+ init_rotation.data(),
731
+ d * d * sizeof(rotation[0]));
631
732
  } else {
632
- RandomRotationMatrix rrot (d, d);
633
- rrot.init (seed);
733
+ RandomRotationMatrix rrot(d, d);
734
+ rrot.init(seed);
634
735
  for (size_t i = 0; i < d * d; i++) {
635
736
  rotation[i] = rrot.A[i];
636
737
  }
637
738
  }
638
739
 
639
- std::vector<double> x (n * d);
740
+ std::vector<double> x(n * d);
640
741
 
641
742
  for (size_t i = 0; i < n * d; i++) {
642
743
  x[i] = xf[i];
643
744
  }
644
745
 
645
- std::vector<double> rotated_x (n * d), cov_mat (d * d);
646
- std::vector<double> u (d * d), vt (d * d), singvals (d);
746
+ std::vector<double> rotated_x(n * d), cov_mat(d * d);
747
+ std::vector<double> u(d * d), vt(d * d), singvals(d);
647
748
 
648
749
  for (int i = 0; i < max_iter; i++) {
649
- print_if_verbose ("rotation", rotation, d, d);
750
+ print_if_verbose("rotation", rotation, d, d);
650
751
  { // rotated_data = np.dot(training_data, rotation)
651
752
  FINTEGER di = d, ni = n;
652
753
  double one = 1, zero = 0;
653
- dgemm_ ("N", "N", &di, &ni, &di,
654
- &one, rotation.data(), &di, x.data(), &di,
655
- &zero, rotated_x.data(), &di);
754
+ dgemm_("N",
755
+ "N",
756
+ &di,
757
+ &ni,
758
+ &di,
759
+ &one,
760
+ rotation.data(),
761
+ &di,
762
+ x.data(),
763
+ &di,
764
+ &zero,
765
+ rotated_x.data(),
766
+ &di);
656
767
  }
657
- print_if_verbose ("rotated_x", rotated_x, n, d);
768
+ print_if_verbose("rotated_x", rotated_x, n, d);
658
769
  // binarize
659
770
  for (size_t j = 0; j < n * d; j++) {
660
771
  rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
@@ -663,88 +774,119 @@ void ITQMatrix::train (Index::idx_t n, const float* xf)
663
774
  { // rotated_data = np.dot(training_data, rotation)
664
775
  FINTEGER di = d, ni = n;
665
776
  double one = 1, zero = 0;
666
- dgemm_ ("N", "T", &di, &di, &ni,
667
- &one, rotated_x.data(), &di, x.data(), &di,
668
- &zero, cov_mat.data(), &di);
777
+ dgemm_("N",
778
+ "T",
779
+ &di,
780
+ &di,
781
+ &ni,
782
+ &one,
783
+ rotated_x.data(),
784
+ &di,
785
+ x.data(),
786
+ &di,
787
+ &zero,
788
+ cov_mat.data(),
789
+ &di);
669
790
  }
670
- print_if_verbose ("cov_mat", cov_mat, d, d);
791
+ print_if_verbose("cov_mat", cov_mat, d, d);
671
792
  // SVD
672
793
  {
673
-
674
794
  FINTEGER di = d;
675
795
  FINTEGER lwork = -1, info;
676
796
  double lwork1;
677
797
 
678
798
  // workspace query
679
- dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
680
- singvals.data(), u.data(), &di,
681
- vt.data(), &di,
682
- &lwork1, &lwork, &info);
683
-
684
- FAISS_THROW_IF_NOT (info == 0);
685
- lwork = size_t (lwork1);
686
- std::vector<double> work (lwork);
687
- dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
688
- singvals.data(), u.data(), &di,
689
- vt.data(), &di,
690
- work.data(), &lwork, &info);
691
- FAISS_THROW_IF_NOT_FMT (info == 0, "sgesvd returned info=%d", info);
692
-
799
+ dgesvd_("A",
800
+ "A",
801
+ &di,
802
+ &di,
803
+ cov_mat.data(),
804
+ &di,
805
+ singvals.data(),
806
+ u.data(),
807
+ &di,
808
+ vt.data(),
809
+ &di,
810
+ &lwork1,
811
+ &lwork,
812
+ &info);
813
+
814
+ FAISS_THROW_IF_NOT(info == 0);
815
+ lwork = size_t(lwork1);
816
+ std::vector<double> work(lwork);
817
+ dgesvd_("A",
818
+ "A",
819
+ &di,
820
+ &di,
821
+ cov_mat.data(),
822
+ &di,
823
+ singvals.data(),
824
+ u.data(),
825
+ &di,
826
+ vt.data(),
827
+ &di,
828
+ work.data(),
829
+ &lwork,
830
+ &info);
831
+ FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
693
832
  }
694
- print_if_verbose ("u", u, d, d);
695
- print_if_verbose ("vt", vt, d, d);
833
+ print_if_verbose("u", u, d, d);
834
+ print_if_verbose("vt", vt, d, d);
696
835
  // update rotation
697
836
  {
698
837
  FINTEGER di = d;
699
838
  double one = 1, zero = 0;
700
- dgemm_ ("N", "T", &di, &di, &di,
701
- &one, u.data(), &di, vt.data(), &di,
702
- &zero, rotation.data(), &di);
839
+ dgemm_("N",
840
+ "T",
841
+ &di,
842
+ &di,
843
+ &di,
844
+ &one,
845
+ u.data(),
846
+ &di,
847
+ vt.data(),
848
+ &di,
849
+ &zero,
850
+ rotation.data(),
851
+ &di);
703
852
  }
704
- print_if_verbose ("final rot", rotation, d, d);
705
-
853
+ print_if_verbose("final rot", rotation, d, d);
706
854
  }
707
- A.resize (d * d);
855
+ A.resize(d * d);
708
856
  for (size_t i = 0; i < d; i++) {
709
857
  for (size_t j = 0; j < d; j++) {
710
858
  A[i + d * j] = rotation[j + d * i];
711
859
  }
712
860
  }
713
861
  is_trained = true;
714
-
715
862
  }
716
863
 
717
- ITQTransform::ITQTransform (int d_in, int d_out, bool do_pca):
718
- VectorTransform (d_in, d_out),
719
- do_pca (do_pca),
720
- itq (d_out),
721
- pca_then_itq (d_in, d_out, false)
722
- {
864
+ ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca)
865
+ : VectorTransform(d_in, d_out),
866
+ do_pca(do_pca),
867
+ itq(d_out),
868
+ pca_then_itq(d_in, d_out, false) {
723
869
  if (!do_pca) {
724
- FAISS_THROW_IF_NOT (d_in == d_out);
870
+ FAISS_THROW_IF_NOT(d_in == d_out);
725
871
  }
726
872
  max_train_per_dim = 10;
727
873
  is_trained = false;
728
874
  }
729
875
 
876
+ void ITQTransform::train(idx_t n, const float* x) {
877
+ FAISS_THROW_IF_NOT(!is_trained);
730
878
 
731
-
732
-
733
- void ITQTransform::train (idx_t n, const float *x)
734
- {
735
- FAISS_THROW_IF_NOT (!is_trained);
736
-
737
- const float * x_in = x;
879
+ const float* x_in = x;
738
880
  size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
739
- x = fvecs_maybe_subsample (d_in, (size_t*)&n, max_train_points, x);
881
+ x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x);
740
882
 
741
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
883
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
742
884
 
743
- std::unique_ptr<float []> x_norm(new float[n * d_in]);
885
+ std::unique_ptr<float[]> x_norm(new float[n * d_in]);
744
886
  { // normalize
745
887
  int d = d_in;
746
888
 
747
- mean.resize (d, 0);
889
+ mean.resize(d, 0);
748
890
  for (idx_t i = 0; i < n; i++) {
749
891
  for (idx_t j = 0; j < d; j++) {
750
892
  mean[j] += x[i * d + j];
@@ -755,38 +897,47 @@ void ITQTransform::train (idx_t n, const float *x)
755
897
  }
756
898
  for (idx_t i = 0; i < n; i++) {
757
899
  for (idx_t j = 0; j < d; j++) {
758
- x_norm[i * d + j] = x[i * d + j] - mean[j];
900
+ x_norm[i * d + j] = x[i * d + j] - mean[j];
759
901
  }
760
902
  }
761
- fvec_renorm_L2 (d_in, n, x_norm.get());
903
+ fvec_renorm_L2(d_in, n, x_norm.get());
762
904
  }
763
905
 
764
906
  // train PCA
765
907
 
766
- PCAMatrix pca (d_in, d_out);
767
- float *x_pca;
768
- std::unique_ptr<float []> x_pca_del;
908
+ PCAMatrix pca(d_in, d_out);
909
+ float* x_pca;
910
+ std::unique_ptr<float[]> x_pca_del;
769
911
  if (do_pca) {
770
- pca.have_bias = false; // for consistency with reference implem
771
- pca.train (n, x_norm.get());
772
- x_pca = pca.apply (n, x_norm.get());
912
+ pca.have_bias = false; // for consistency with reference implem
913
+ pca.train(n, x_norm.get());
914
+ x_pca = pca.apply(n, x_norm.get());
773
915
  x_pca_del.reset(x_pca);
774
916
  } else {
775
917
  x_pca = x_norm.get();
776
918
  }
777
919
 
778
920
  // train ITQ
779
- itq.train (n, x_pca);
921
+ itq.train(n, x_pca);
780
922
 
781
923
  // merge PCA and ITQ
782
924
  if (do_pca) {
783
925
  FINTEGER di = d_out, dini = d_in;
784
926
  float one = 1, zero = 0;
785
927
  pca_then_itq.A.resize(d_in * d_out);
786
- sgemm_ ("N", "N", &dini, &di, &di,
787
- &one, pca.A.data(), &dini,
788
- itq.A.data(), &di,
789
- &zero, pca_then_itq.A.data(), &dini);
928
+ sgemm_("N",
929
+ "N",
930
+ &dini,
931
+ &di,
932
+ &di,
933
+ &one,
934
+ pca.A.data(),
935
+ &dini,
936
+ itq.A.data(),
937
+ &di,
938
+ &zero,
939
+ pca_then_itq.A.data(),
940
+ &dini);
790
941
  } else {
791
942
  pca_then_itq.A = itq.A;
792
943
  }
@@ -794,12 +945,11 @@ void ITQTransform::train (idx_t n, const float *x)
794
945
  is_trained = true;
795
946
  }
796
947
 
797
- void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
798
- float * xt) const
799
- {
948
+ void ITQTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
949
+ const {
800
950
  FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
801
951
 
802
- std::unique_ptr<float []> x_norm(new float[n * d_in]);
952
+ std::unique_ptr<float[]> x_norm(new float[n * d_in]);
803
953
  { // normalize
804
954
  int d = d_in;
805
955
  for (idx_t i = 0; i < n; i++) {
@@ -809,41 +959,36 @@ void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
809
959
  }
810
960
  // this is not really useful if we are going to binarize right
811
961
  // afterwards but OK
812
- fvec_renorm_L2 (d_in, n, x_norm.get());
962
+ fvec_renorm_L2(d_in, n, x_norm.get());
813
963
  }
814
964
 
815
- pca_then_itq.apply_noalloc (n, x_norm.get(), xt);
965
+ pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
816
966
  }
817
967
 
818
968
  /*********************************************
819
969
  * OPQMatrix
820
970
  *********************************************/
821
971
 
822
-
823
- OPQMatrix::OPQMatrix (int d, int M, int d2):
824
- LinearTransform (d, d2 == -1 ? d : d2, false), M(M),
825
- niter (50),
826
- niter_pq (4), niter_pq_0 (40),
827
- verbose(false),
828
- pq(nullptr)
829
- {
972
+ OPQMatrix::OPQMatrix(int d, int M, int d2)
973
+ : LinearTransform(d, d2 == -1 ? d : d2, false),
974
+ M(M),
975
+ niter(50),
976
+ niter_pq(4),
977
+ niter_pq_0(40),
978
+ verbose(false),
979
+ pq(nullptr) {
830
980
  is_trained = false;
831
981
  // OPQ is quite expensive to train, so set this right.
832
982
  max_train_points = 256 * 256;
833
983
  pq = nullptr;
834
984
  }
835
985
 
986
+ void OPQMatrix::train(Index::idx_t n, const float* x) {
987
+ const float* x_in = x;
836
988
 
989
+ x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose);
837
990
 
838
- void OPQMatrix::train (Index::idx_t n, const float *x)
839
- {
840
-
841
- const float * x_in = x;
842
-
843
- x = fvecs_maybe_subsample (d_in, (size_t*)&n,
844
- max_train_points, x, verbose);
845
-
846
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
991
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
847
992
 
848
993
  // To support d_out > d_in, we pad input vectors with 0s to d_out
849
994
  size_t d = d_out <= d_in ? d_in : d_out;
@@ -867,22 +1012,26 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
867
1012
  #endif
868
1013
 
869
1014
  if (verbose) {
870
- printf ("OPQMatrix::train: training an OPQ rotation matrix "
871
- "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
872
- M, n, d_in, d_out);
1015
+ printf("OPQMatrix::train: training an OPQ rotation matrix "
1016
+ "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
1017
+ M,
1018
+ n,
1019
+ d_in,
1020
+ d_out);
873
1021
  }
874
1022
 
875
- std::vector<float> xtrain (n * d);
1023
+ std::vector<float> xtrain(n * d);
876
1024
  // center x
877
1025
  {
878
- std::vector<float> sum (d);
879
- const float *xi = x;
1026
+ std::vector<float> sum(d);
1027
+ const float* xi = x;
880
1028
  for (size_t i = 0; i < n; i++) {
881
1029
  for (int j = 0; j < d_in; j++)
882
- sum [j] += *xi++;
1030
+ sum[j] += *xi++;
883
1031
  }
884
- for (int i = 0; i < d; i++) sum[i] /= n;
885
- float *yi = xtrain.data();
1032
+ for (int i = 0; i < d; i++)
1033
+ sum[i] /= n;
1034
+ float* yi = xtrain.data();
886
1035
  xi = x;
887
1036
  for (size_t i = 0; i < n; i++) {
888
1037
  for (int j = 0; j < d_in; j++)
@@ -890,71 +1039,80 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  yi += d - d_in;
  }
  }
- float *rotation;
+ float* rotation;

- if (A.size () == 0) {
- A.resize (d * d);
+ if (A.size() == 0) {
+ A.resize(d * d);
  rotation = A.data();
  if (verbose)
  printf(" OPQMatrix::train: making random %zd*%zd rotation\n",
- d, d);
- float_randn (rotation, d * d, 1234);
- matrix_qr (d, d, rotation);
+ d,
+ d);
+ float_randn(rotation, d * d, 1234);
+ matrix_qr(d, d, rotation);
  // we use only the d * d2 upper part of the matrix
- A.resize (d * d2);
+ A.resize(d * d2);
  } else {
- FAISS_THROW_IF_NOT (A.size() == d * d2);
+ FAISS_THROW_IF_NOT(A.size() == d * d2);
  rotation = A.data();
  }

- std::vector<float>
- xproj (d2 * n), pq_recons (d2 * n), xxr (d * n),
- tmp(d * d * 4);
-
+ std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * n),
+ tmp(d * d * 4);

- ProductQuantizer pq_default (d2, M, 8);
- ProductQuantizer &pq_regular = pq ? *pq : pq_default;
- std::vector<uint8_t> codes (pq_regular.code_size * n);
+ ProductQuantizer pq_default(d2, M, 8);
+ ProductQuantizer& pq_regular = pq ? *pq : pq_default;
+ std::vector<uint8_t> codes(pq_regular.code_size * n);

  double t0 = getmillisecs();
  for (int iter = 0; iter < niter; iter++) {
-
  { // torch.mm(xtrain, rotation:t())
  FINTEGER di = d, d2i = d2, ni = n;
  float zero = 0, one = 1;
- sgemm_ ("Transposed", "Not transposed",
- &d2i, &ni, &di,
- &one, rotation, &di,
- xtrain.data(), &di,
- &zero, xproj.data(), &d2i);
+ sgemm_("Transposed",
+ "Not transposed",
+ &d2i,
+ &ni,
+ &di,
+ &one,
+ rotation,
+ &di,
+ xtrain.data(),
+ &di,
+ &zero,
+ xproj.data(),
+ &d2i);
  }

  pq_regular.cp.max_points_per_centroid = 1000;
  pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
  pq_regular.verbose = verbose;
- pq_regular.train (n, xproj.data());
+ pq_regular.train(n, xproj.data());

  if (verbose) {
  printf(" encode / decode\n");
  }
  if (pq_regular.assign_index) {
- pq_regular.compute_codes_with_assign_index
- (xproj.data(), codes.data(), n);
+ pq_regular.compute_codes_with_assign_index(
+ xproj.data(), codes.data(), n);
  } else {
- pq_regular.compute_codes (xproj.data(), codes.data(), n);
+ pq_regular.compute_codes(xproj.data(), codes.data(), n);
  }
- pq_regular.decode (codes.data(), pq_recons.data(), n);
+ pq_regular.decode(codes.data(), pq_recons.data(), n);

- float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n;
+ float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;

  if (verbose)
- printf (" Iteration %d (%d PQ iterations):"
- "%.3f s, obj=%g\n", iter, pq_regular.cp.niter,
- (getmillisecs () - t0) / 1000.0, pq_err);
+ printf(" Iteration %d (%d PQ iterations):"
+ "%.3f s, obj=%g\n",
+ iter,
+ pq_regular.cp.niter,
+ (getmillisecs() - t0) / 1000.0,
+ pq_err);

  {
- float *u = tmp.data(), *vt = &tmp [d * d];
- float *sing_val = &tmp [2 * d * d];
+ float *u = tmp.data(), *vt = &tmp[d * d];
+ float* sing_val = &tmp[2 * d * d];
  FINTEGER di = d, d2i = d2, ni = n;
  float one = 1, zero = 0;

@@ -962,36 +1120,69 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  printf(" X * recons\n");
  }
  // torch.mm(xtrain:t(), pq_recons)
- sgemm_ ("Not", "Transposed",
- &d2i, &di, &ni,
- &one, pq_recons.data(), &d2i,
- xtrain.data(), &di,
- &zero, xxr.data(), &d2i);
-
+ sgemm_("Not",
+ "Transposed",
+ &d2i,
+ &di,
+ &ni,
+ &one,
+ pq_recons.data(),
+ &d2i,
+ xtrain.data(),
+ &di,
+ &zero,
+ xxr.data(),
+ &d2i);

  FINTEGER lwork = -1, info = -1;
  float worksz;
  // workspace query
- sgesvd_ ("All", "All",
- &d2i, &di, xxr.data(), &d2i,
- sing_val,
- vt, &d2i, u, &di,
- &worksz, &lwork, &info);
+ sgesvd_("All",
+ "All",
+ &d2i,
+ &di,
+ xxr.data(),
+ &d2i,
+ sing_val,
+ vt,
+ &d2i,
+ u,
+ &di,
+ &worksz,
+ &lwork,
+ &info);

  lwork = int(worksz);
- std::vector<float> work (lwork);
+ std::vector<float> work(lwork);
  // u and vt swapped
- sgesvd_ ("All", "All",
- &d2i, &di, xxr.data(), &d2i,
- sing_val,
- vt, &d2i, u, &di,
- work.data(), &lwork, &info);
-
- sgemm_ ("Transposed", "Transposed",
- &di, &d2i, &d2i,
- &one, u, &di, vt, &d2i,
- &zero, rotation, &di);
-
+ sgesvd_("All",
+ "All",
+ &d2i,
+ &di,
+ xxr.data(),
+ &d2i,
+ sing_val,
+ vt,
+ &d2i,
+ u,
+ &di,
+ work.data(),
+ &lwork,
+ &info);
+
+ sgemm_("Transposed",
+ "Transposed",
+ &di,
+ &d2i,
+ &d2i,
+ &one,
+ u,
+ &di,
+ vt,
+ &d2i,
+ &zero,
+ rotation,
+ &di);
  }
  pq_regular.train_type = ProductQuantizer::Train_hot_start;
  }
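The SVD block in this hunk is the orthogonal Procrustes update at the core of OPQ training: the PQ reconstructions are held fixed and the rotation is re-solved in closed form. A sketch of the math, stated independently of the transposition conventions of the column-major BLAS calls: with centered training vectors $X \in \mathbb{R}^{n \times d}$ and their PQ reconstructions $\hat{X} \in \mathbb{R}^{n \times d_2}$,

$$ M = X^\top \hat{X}, \qquad M = U \Sigma V^\top, \qquad R = U V^\top, $$

where $R$ minimizes $\lVert X R - \hat{X} \rVert_F$ over matrices with orthonormal columns; the final sgemm_ writes this product back into rotation.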
@@ -999,59 +1190,52 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  // revert A matrix
  if (d > d_in) {
  for (long i = 0; i < d_out; i++)
- memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
- A.resize (d_in * d_out);
+ memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
+ A.resize(d_in * d_out);
  }

  is_trained = true;
  is_orthonormal = true;
  }

-
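In practice OPQMatrix is the learned rotation placed in front of a product quantizer (for example via an IndexPreTransform). A minimal standalone sketch, assuming only the OPQMatrix(d, M, d2) constructor and the train/apply_noalloc methods that appear in this file:

    // Hypothetical example: learn an OPQ rotation for 64-dim vectors that will be
    // encoded with M = 8 sub-quantizers, then rotate the database with it.
    #include <faiss/VectorTransform.h>
    #include <random>
    #include <vector>

    int main() {
        int d = 64, M = 8, nb = 20000;
        std::vector<float> xb(size_t(nb) * d);
        std::mt19937 rng(42);
        std::normal_distribution<float> gauss;
        for (auto& v : xb) v = gauss(rng);

        faiss::OPQMatrix opq(d, M, d);   // d2 == d: rotate without changing dimension
        opq.verbose = true;
        opq.train(nb, xb.data());        // alternates PQ training and rotation updates

        std::vector<float> xb_rot(size_t(nb) * d);
        opq.apply_noalloc(nb, xb.data(), xb_rot.data()); // input for the PQ encoder
        return 0;
    }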
  /*********************************************
  * NormalizationTransform
  *********************************************/

- NormalizationTransform::NormalizationTransform (int d, float norm):
- VectorTransform (d, d), norm (norm)
- {
- }
+ NormalizationTransform::NormalizationTransform(int d, float norm)
+ : VectorTransform(d, d), norm(norm) {}

- NormalizationTransform::NormalizationTransform ():
- VectorTransform (-1, -1), norm (-1)
- {
- }
+ NormalizationTransform::NormalizationTransform()
+ : VectorTransform(-1, -1), norm(-1) {}

- void NormalizationTransform::apply_noalloc
- (idx_t n, const float* x, float* xt) const
- {
+ void NormalizationTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+ const {
  if (norm == 2.0) {
- memcpy (xt, x, sizeof (x[0]) * n * d_in);
- fvec_renorm_L2 (d_in, n, xt);
+ memcpy(xt, x, sizeof(x[0]) * n * d_in);
+ fvec_renorm_L2(d_in, n, xt);
  } else {
- FAISS_THROW_MSG ("not implemented");
+ FAISS_THROW_MSG("not implemented");
  }
  }

- void NormalizationTransform::reverse_transform (idx_t n, const float* xt,
- float* x) const
- {
- memcpy (x, xt, sizeof (xt[0]) * n * d_in);
+ void NormalizationTransform::reverse_transform(
+ idx_t n,
+ const float* xt,
+ float* x) const {
+ memcpy(x, xt, sizeof(xt[0]) * n * d_in);
  }

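As the hunk shows, NormalizationTransform only implements norm == 2.0 (plain L2 renormalization) and its reverse_transform is a straight copy. A short sketch of that behavior, assuming the NormalizationTransform(d, norm) constructor from faiss/VectorTransform.h:

    // Hypothetical example: L2-normalize two 4-dim vectors; any norm other than
    // 2.0 makes apply_noalloc throw "not implemented".
    #include <faiss/VectorTransform.h>
    #include <cstdio>
    #include <vector>

    int main() {
        int d = 4;
        std::vector<float> x = {3, 4, 0, 0,   // L2 norm 5
                                0, 0, 0, 2};  // L2 norm 2
        std::vector<float> xt(x.size());

        faiss::NormalizationTransform l2(d, 2.0);
        l2.apply_noalloc(2, x.data(), xt.data());
        std::printf("%g %g %g %g\n", xt[0], xt[1], xt[2], xt[3]); // 0.6 0.8 0 0
        return 0;
    }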
  /*********************************************
  * CenteringTransform
  *********************************************/

- CenteringTransform::CenteringTransform (int d):
- VectorTransform (d, d)
- {
+ CenteringTransform::CenteringTransform(int d) : VectorTransform(d, d) {
  is_trained = false;
  }

- void CenteringTransform::train(Index::idx_t n, const float *x) {
+ void CenteringTransform::train(Index::idx_t n, const float* x) {
  FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
- mean.resize (d_in, 0);
+ mean.resize(d_in, 0);
  for (idx_t i = 0; i < n; i++) {
  for (size_t j = 0; j < d_in; j++) {
  mean[j] += *x++;
@@ -1064,11 +1248,9 @@ void CenteringTransform::train(Index::idx_t n, const float *x) {
  is_trained = true;
  }

-
- void CenteringTransform::apply_noalloc
- (idx_t n, const float* x, float* xt) const
- {
- FAISS_THROW_IF_NOT (is_trained);
+ void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+ const {
+ FAISS_THROW_IF_NOT(is_trained);

  for (idx_t i = 0; i < n; i++) {
  for (size_t j = 0; j < d_in; j++) {
@@ -1077,64 +1259,58 @@ void CenteringTransform::apply_noalloc
  }
  }

- void CenteringTransform::reverse_transform (idx_t n, const float* xt,
- float* x) const
- {
- FAISS_THROW_IF_NOT (is_trained);
+ void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
+ const {
+ FAISS_THROW_IF_NOT(is_trained);

  for (idx_t i = 0; i < n; i++) {
  for (size_t j = 0; j < d_in; j++) {
  *x++ = *xt++ + mean[j];
  }
  }
-
  }

-
-
-
-
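CenteringTransform simply learns the per-dimension mean in train() and subtracts or re-adds it in apply_noalloc() and reverse_transform(). A round-trip sketch, assuming the CenteringTransform(d) constructor from faiss/VectorTransform.h:

    // Hypothetical example: center three 2-dim vectors around their mean and
    // recover them again; the round trip is exact.
    #include <faiss/VectorTransform.h>
    #include <cstdio>
    #include <vector>

    int main() {
        int d = 2, n = 3;
        std::vector<float> x = {1, 10, 2, 20, 3, 30}; // column means: 2 and 20
        std::vector<float> xt(x.size()), back(x.size());

        faiss::CenteringTransform center(d);
        center.train(n, x.data());                            // mean = {2, 20}
        center.apply_noalloc(n, x.data(), xt.data());         // xt = x - mean
        center.reverse_transform(n, xt.data(), back.data());  // back == x
        std::printf("%g %g\n", xt[0], xt[1]); // -1 -10
        return 0;
    }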
  /*********************************************
  * RemapDimensionsTransform
  *********************************************/

-
- RemapDimensionsTransform::RemapDimensionsTransform (
- int d_in, int d_out, const int *map_in):
- VectorTransform (d_in, d_out)
- {
- map.resize (d_out);
+ RemapDimensionsTransform::RemapDimensionsTransform(
+ int d_in,
+ int d_out,
+ const int* map_in)
+ : VectorTransform(d_in, d_out) {
+ map.resize(d_out);
  for (int i = 0; i < d_out; i++) {
  map[i] = map_in[i];
- FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
+ FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
  }
  }

- RemapDimensionsTransform::RemapDimensionsTransform (
- int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out)
- {
- map.resize (d_out, -1);
+ RemapDimensionsTransform::RemapDimensionsTransform(
+ int d_in,
+ int d_out,
+ bool uniform)
+ : VectorTransform(d_in, d_out) {
+ map.resize(d_out, -1);

  if (uniform) {
  if (d_in < d_out) {
  for (int i = 0; i < d_in; i++) {
- map [i * d_out / d_in] = i;
- }
+ map[i * d_out / d_in] = i;
+ }
  } else {
  for (int i = 0; i < d_out; i++) {
- map [i] = i * d_in / d_out;
+ map[i] = i * d_in / d_out;
  }
  }
  } else {
  for (int i = 0; i < d_in && i < d_out; i++)
- map [i] = i;
+ map[i] = i;
  }
  }

-
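The uniform constructor above either spreads the input dimensions over a larger output (d_in < d_out, with unset slots left at 0) or subsamples them at a regular stride when the output is smaller. A small sketch, assuming the RemapDimensionsTransform(d_in, d_out, uniform) constructor from faiss/VectorTransform.h:

    // Hypothetical example: uniformly subsample 8 input dimensions down to 4;
    // with d_in = 8 and d_out = 4 the map becomes {0, 2, 4, 6}.
    #include <faiss/VectorTransform.h>
    #include <cstdio>
    #include <vector>

    int main() {
        int d_in = 8, d_out = 4;
        std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};
        std::vector<float> xt(d_out);

        faiss::RemapDimensionsTransform remap(d_in, d_out, /*uniform=*/true);
        remap.apply_noalloc(1, x.data(), xt.data());
        std::printf("%g %g %g %g\n", xt[0], xt[1], xt[2], xt[3]); // 0 2 4 6
        return 0;
    }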
- void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
- float *xt) const
- {
+ void RemapDimensionsTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+ const {
  for (idx_t i = 0; i < n; i++) {
  for (int j = 0; j < d_out; j++) {
  xt[j] = map[j] < 0 ? 0 : x[map[j]];
@@ -1144,13 +1320,15 @@ void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
  }
  }

- void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt,
- float *x) const
- {
- memset (x, 0, sizeof (*x) * n * d_in);
+ void RemapDimensionsTransform::reverse_transform(
+ idx_t n,
+ const float* xt,
+ float* x) const {
+ memset(x, 0, sizeof(*x) * n * d_in);
  for (idx_t i = 0; i < n; i++) {
  for (int j = 0; j < d_out; j++) {
- if (map[j] >= 0) x[map[j]] = xt[j];
+ if (map[j] >= 0)
+ x[map[j]] = xt[j];
  }
  x += d_in;
  xt += d_out;