faiss 0.1.4 → 0.2.1

Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +26 -1
  3. data/README.md +15 -3
  4. data/ext/faiss/ext.cpp +12 -308
  5. data/ext/faiss/extconf.rb +5 -2
  6. data/ext/faiss/index.cpp +189 -0
  7. data/ext/faiss/index_binary.cpp +75 -0
  8. data/ext/faiss/kmeans.cpp +40 -0
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +33 -0
  11. data/ext/faiss/product_quantizer.cpp +53 -0
  12. data/ext/faiss/utils.cpp +13 -0
  13. data/ext/faiss/utils.h +5 -0
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +31 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
data/vendor/faiss/faiss/MetaIndexes.h

@@ -10,11 +10,11 @@
  #ifndef META_INDEXES_H
  #define META_INDEXES_H

- #include <vector>
- #include <unordered_map>
  #include <faiss/Index.h>
- #include <faiss/IndexShards.h>
  #include <faiss/IndexReplicas.h>
+ #include <faiss/IndexShards.h>
+ #include <unordered_map>
+ #include <vector>

  namespace faiss {

@@ -25,22 +25,25 @@ struct IndexIDMapTemplate : IndexT {
  using component_t = typename IndexT::component_t;
  using distance_t = typename IndexT::distance_t;

- IndexT * index; ///! the sub-index
- bool own_fields; ///! whether pointers are deleted in destructo
+ IndexT* index; ///! the sub-index
+ bool own_fields; ///! whether pointers are deleted in destructo
  std::vector<idx_t> id_map;

- explicit IndexIDMapTemplate (IndexT *index);
+ explicit IndexIDMapTemplate(IndexT* index);

  /// @param xids if non-null, ids to store for the vectors (size n)
- void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override;
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
+ override;

  /// this will fail. Use add_with_ids
  void add(idx_t n, const component_t* x) override;

  void search(
- idx_t n, const component_t* x, idx_t k,
- distance_t* distances,
- idx_t* labels) const override;
+ idx_t n,
+ const component_t* x,
+ idx_t k,
+ distance_t* distances,
+ idx_t* labels) const override;

  void train(idx_t n, const component_t* x) override;

@@ -49,17 +52,22 @@ struct IndexIDMapTemplate : IndexT {
  /// remove ids adapted to IndexFlat
  size_t remove_ids(const IDSelector& sel) override;

- void range_search (idx_t n, const component_t *x, distance_t radius,
- RangeSearchResult *result) const override;
-
- ~IndexIDMapTemplate () override;
- IndexIDMapTemplate () {own_fields=false; index=nullptr; }
+ void range_search(
+ idx_t n,
+ const component_t* x,
+ distance_t radius,
+ RangeSearchResult* result) const override;
+
+ ~IndexIDMapTemplate() override;
+ IndexIDMapTemplate() {
+ own_fields = false;
+ index = nullptr;
+ }
  };

  using IndexIDMap = IndexIDMapTemplate<Index>;
  using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;

-
  /** same as IndexIDMap but also provides an efficient reconstruction
  * implementation via a 2-way index */
  template <typename IndexT>
@@ -70,47 +78,47 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {

  std::unordered_map<idx_t, idx_t> rev_map;

- explicit IndexIDMap2Template (IndexT *index);
+ explicit IndexIDMap2Template(IndexT* index);

  /// make the rev_map from scratch
- void construct_rev_map ();
+ void construct_rev_map();

- void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override;
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
+ override;

  size_t remove_ids(const IDSelector& sel) override;

- void reconstruct (idx_t key, component_t * recons) const override;
+ void reconstruct(idx_t key, component_t* recons) const override;

  ~IndexIDMap2Template() override {}
- IndexIDMap2Template () {}
+ IndexIDMap2Template() {}
  };

  using IndexIDMap2 = IndexIDMap2Template<Index>;
  using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;

-
  /** splits input vectors in segments and assigns each segment to a sub-index
  * used to distribute a MultiIndexQuantizer
  */
- struct IndexSplitVectors: Index {
+ struct IndexSplitVectors : Index {
  bool own_fields;
  bool threaded;
  std::vector<Index*> sub_indexes;
- idx_t sum_d; /// sum of dimensions seen so far
+ idx_t sum_d; /// sum of dimensions seen so far

- explicit IndexSplitVectors (idx_t d, bool threaded = false);
+ explicit IndexSplitVectors(idx_t d, bool threaded = false);

- void add_sub_index (Index *);
- void sync_with_sub_indexes ();
+ void add_sub_index(Index*);
+ void sync_with_sub_indexes();

  void add(idx_t n, const float* x) override;

  void search(
- idx_t n,
- const float* x,
- idx_t k,
- float* distances,
- idx_t* labels) const override;
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const override;

  void train(idx_t n, const float* x) override;

@@ -119,8 +127,6 @@ struct IndexSplitVectors: Index {
  ~IndexSplitVectors() override;
  };

-
  } // namespace faiss

-
  #endif
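The MetaIndexes.h hunks above are formatting-only (clang-format style); the declared signatures are unchanged. For orientation, a minimal usage sketch of the IndexIDMap wrapper declared above — standard faiss C++ API, with made-up data values for illustration:

```cpp
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <vector>

int main() {
    int d = 8;
    faiss::IndexFlatL2 flat(d);      // the wrapped sub-index
    faiss::IndexIDMap idmap(&flat);  // own_fields stays false: caller keeps ownership

    // five vectors with caller-chosen ids; plain add() would fail,
    // per the "this will fail. Use add_with_ids" comment above
    std::vector<float> xb(5 * d, 0.5f);
    std::vector<faiss::Index::idx_t> ids = {100, 101, 102, 103, 104};
    idmap.add_with_ids(5, xb.data(), ids.data());

    // search returns the stored ids rather than internal positions
    std::vector<float> dist(3);
    std::vector<faiss::Index::idx_t> labels(3);
    idmap.search(1, xb.data(), 3, dist.data(), labels.data());
    return 0;
}
```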
data/vendor/faiss/faiss/MetricType.h

@@ -18,12 +18,12 @@ namespace faiss {
  /// (brute-force) indices supporting additional metric types for vector
  /// comparison.
  enum MetricType {
- METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
- METRIC_L2 = 1, ///< squared L2 search
- METRIC_L1, ///< L1 (aka cityblock)
- METRIC_Linf, ///< infinity distance
- METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
- /// metric_arg
+ METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
+ METRIC_L2 = 1, ///< squared L2 search
+ METRIC_L1, ///< L1 (aka cityblock)
+ METRIC_Linf, ///< infinity distance
+ METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
+ /// metric_arg

  /// some additional metrics defined in scipy.spatial.distance
  METRIC_Canberra = 20,
@@ -31,6 +31,6 @@ enum MetricType {
  METRIC_JensenShannon,
  };

- }
+ } // namespace faiss

  #endif
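The MetricType hunks are likewise formatting-only. As a reminder of how these enum values are consumed, a small sketch using the standard faiss API (not part of this diff):

```cpp
#include <faiss/IndexFlat.h>

int main() {
    int d = 32;
    // METRIC_L2 is the default for most index constructors
    faiss::IndexFlat l2(d);
    // brute-force indices also accept the extra metrics listed above
    faiss::IndexFlat l1(d, faiss::METRIC_L1);
    faiss::IndexFlat lp(d, faiss::METRIC_Lp);
    lp.metric_arg = 1.5f;  // p for METRIC_Lp, as the enum comment notes
    return 0;
}
```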
data/vendor/faiss/faiss/VectorTransform.cpp

@@ -10,20 +10,19 @@
  #include <faiss/VectorTransform.h>

  #include <cinttypes>
- #include <cstdio>
  #include <cmath>
+ #include <cstdio>
  #include <cstring>
  #include <memory>

+ #include <faiss/IndexPQ.h>
+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/distances.h>
  #include <faiss/utils/random.h>
  #include <faiss/utils/utils.h>
- #include <faiss/impl/FaissAssert.h>
- #include <faiss/IndexPQ.h>

  using namespace faiss;

-
  extern "C" {

  // this is to keep the clang syntax checker happy
@@ -31,134 +30,183 @@ extern "C" {
  #define FINTEGER int
  #endif

-
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

- int sgemm_ (
- const char *transa, const char *transb, FINTEGER *m, FINTEGER *
- n, FINTEGER *k, const float *alpha, const float *a,
- FINTEGER *lda, const float *b,
- FINTEGER *ldb, float *beta,
- float *c, FINTEGER *ldc);
-
- int dgemm_ (
- const char *transa, const char *transb, FINTEGER *m, FINTEGER *
- n, FINTEGER *k, const double *alpha, const double *a,
- FINTEGER *lda, const double *b,
- FINTEGER *ldb, double *beta,
- double *c, FINTEGER *ldc);
-
- int ssyrk_ (
- const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k,
- float *alpha, float *a, FINTEGER *lda,
- float *beta, float *c, FINTEGER *ldc);
+ int sgemm_(
+ const char* transa,
+ const char* transb,
+ FINTEGER* m,
+ FINTEGER* n,
+ FINTEGER* k,
+ const float* alpha,
+ const float* a,
+ FINTEGER* lda,
+ const float* b,
+ FINTEGER* ldb,
+ float* beta,
+ float* c,
+ FINTEGER* ldc);
+
+ int dgemm_(
+ const char* transa,
+ const char* transb,
+ FINTEGER* m,
+ FINTEGER* n,
+ FINTEGER* k,
+ const double* alpha,
+ const double* a,
+ FINTEGER* lda,
+ const double* b,
+ FINTEGER* ldb,
+ double* beta,
+ double* c,
+ FINTEGER* ldc);
+
+ int ssyrk_(
+ const char* uplo,
+ const char* trans,
+ FINTEGER* n,
+ FINTEGER* k,
+ float* alpha,
+ float* a,
+ FINTEGER* lda,
+ float* beta,
+ float* c,
+ FINTEGER* ldc);

  /* Lapack functions from http://www.netlib.org/clapack/old/single/ */

- int ssyev_ (
- const char *jobz, const char *uplo, FINTEGER *n, float *a,
- FINTEGER *lda, float *w, float *work, FINTEGER *lwork,
- FINTEGER *info);
-
- int dsyev_ (
- const char *jobz, const char *uplo, FINTEGER *n, double *a,
- FINTEGER *lda, double *w, double *work, FINTEGER *lwork,
- FINTEGER *info);
+ int ssyev_(
+ const char* jobz,
+ const char* uplo,
+ FINTEGER* n,
+ float* a,
+ FINTEGER* lda,
+ float* w,
+ float* work,
+ FINTEGER* lwork,
+ FINTEGER* info);
+
+ int dsyev_(
+ const char* jobz,
+ const char* uplo,
+ FINTEGER* n,
+ double* a,
+ FINTEGER* lda,
+ double* w,
+ double* work,
+ FINTEGER* lwork,
+ FINTEGER* info);

  int sgesvd_(
- const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
- float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt,
- FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info);
-
+ const char* jobu,
+ const char* jobvt,
+ FINTEGER* m,
+ FINTEGER* n,
+ float* a,
+ FINTEGER* lda,
+ float* s,
+ float* u,
+ FINTEGER* ldu,
+ float* vt,
+ FINTEGER* ldvt,
+ float* work,
+ FINTEGER* lwork,
+ FINTEGER* info);

  int dgesvd_(
- const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
- double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt,
- FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info);
-
+ const char* jobu,
+ const char* jobvt,
+ FINTEGER* m,
+ FINTEGER* n,
+ double* a,
+ FINTEGER* lda,
+ double* s,
+ double* u,
+ FINTEGER* ldu,
+ double* vt,
+ FINTEGER* ldvt,
+ double* work,
+ FINTEGER* lwork,
+ FINTEGER* info);
  }

  /*********************************************
  * VectorTransform
  *********************************************/

-
-
- float * VectorTransform::apply (Index::idx_t n, const float * x) const
- {
- float * xt = new float[n * d_out];
- apply_noalloc (n, x, xt);
+ float* VectorTransform::apply(Index::idx_t n, const float* x) const {
+ float* xt = new float[n * d_out];
+ apply_noalloc(n, x, xt);
  return xt;
  }

-
- void VectorTransform::train (idx_t, const float *) {
+ void VectorTransform::train(idx_t, const float*) {
  // does nothing by default
  }

-
- void VectorTransform::reverse_transform (
- idx_t , const float *,
- float *) const
- {
- FAISS_THROW_MSG ("reverse transform not implemented");
+ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
+ FAISS_THROW_MSG("reverse transform not implemented");
  }

-
-
-
  /*********************************************
  * LinearTransform
  *********************************************/

  /// both d_in > d_out and d_out < d_in are supported
- LinearTransform::LinearTransform (int d_in, int d_out,
- bool have_bias):
- VectorTransform (d_in, d_out), have_bias (have_bias),
- is_orthonormal (false), verbose (false)
- {
+ LinearTransform::LinearTransform(int d_in, int d_out, bool have_bias)
+ : VectorTransform(d_in, d_out),
+ have_bias(have_bias),
+ is_orthonormal(false),
+ verbose(false) {
  is_trained = false; // will be trained when A and b are initialized
  }

- void LinearTransform::apply_noalloc (Index::idx_t n, const float * x,
- float * xt) const
- {
+ void LinearTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
+ const {
  FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");

  float c_factor;
  if (have_bias) {
- FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized");
- float * xi = xt;
+ FAISS_THROW_IF_NOT_MSG(b.size() == d_out, "Bias not initialized");
+ float* xi = xt;
  for (int i = 0; i < n; i++)
- for(int j = 0; j < d_out; j++)
+ for (int j = 0; j < d_out; j++)
  *xi++ = b[j];
  c_factor = 1.0;
  } else {
  c_factor = 0.0;
  }

- FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in,
- "Transformation matrix not initialized");
+ FAISS_THROW_IF_NOT_MSG(
+ A.size() == d_out * d_in, "Transformation matrix not initialized");

  float one = 1;
  FINTEGER nbiti = d_out, ni = n, di = d_in;
- sgemm_ ("Transposed", "Not transposed",
- &nbiti, &ni, &di,
- &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti);
-
+ sgemm_("Transposed",
+ "Not transposed",
+ &nbiti,
+ &ni,
+ &di,
+ &one,
+ A.data(),
+ &di,
+ x,
+ &di,
+ &c_factor,
+ xt,
+ &nbiti);
  }

-
- void LinearTransform::transform_transpose (idx_t n, const float * y,
- float *x) const
- {
+ void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
+ const {
  if (have_bias) { // allocate buffer to store bias-corrected data
- float *y_new = new float [n * d_out];
- const float *yr = y;
- float *yw = y_new;
+ float* y_new = new float[n * d_out];
+ const float* yr = y;
+ float* yw = y_new;
  for (idx_t i = 0; i < n; i++) {
  for (int j = 0; j < d_out; j++) {
- *yw++ = *yr++ - b [j];
+ *yw++ = *yr++ - b[j];
  }
  }
  y = y_new;
@@ -167,15 +215,26 @@ void LinearTransform::transform_transpose (idx_t n, const float * y,
  {
  FINTEGER dii = d_in, doi = d_out, ni = n;
  float one = 1.0, zero = 0.0;
- sgemm_ ("Not", "Not", &dii, &ni, &doi,
- &one, A.data (), &dii, y, &doi, &zero, x, &dii);
+ sgemm_("Not",
+ "Not",
+ &dii,
+ &ni,
+ &doi,
+ &one,
+ A.data(),
+ &dii,
+ y,
+ &doi,
+ &zero,
+ x,
+ &dii);
  }

- if (have_bias) delete [] y;
+ if (have_bias)
+ delete[] y;
  }

- void LinearTransform::set_is_orthonormal ()
- {
+ void LinearTransform::set_is_orthonormal() {
  if (d_out > d_in) {
  // not clear what we should do in this case
  is_orthonormal = false;
@@ -193,44 +252,53 @@ void LinearTransform::set_is_orthonormal ()
  FINTEGER dii = d_in, doi = d_out;
  float one = 1.0, zero = 0.0;

- sgemm_ ("Transposed", "Not", &doi, &doi, &dii,
- &one, A.data (), &dii,
- A.data(), &dii,
- &zero, ATA.data(), &doi);
+ sgemm_("Transposed",
+ "Not",
+ &doi,
+ &doi,
+ &dii,
+ &one,
+ A.data(),
+ &dii,
+ A.data(),
+ &dii,
+ &zero,
+ ATA.data(),
+ &doi);

  is_orthonormal = true;
  for (long i = 0; i < d_out; i++) {
  for (long j = 0; j < d_out; j++) {
  float v = ATA[i + j * d_out];
- if (i == j) v-= 1;
+ if (i == j)
+ v -= 1;
  if (fabs(v) > eps) {
  is_orthonormal = false;
  }
  }
  }
  }
-
  }

-
- void LinearTransform::reverse_transform (idx_t n, const float * xt,
- float *x) const
- {
+ void LinearTransform::reverse_transform(idx_t n, const float* xt, float* x)
+ const {
  if (is_orthonormal) {
- transform_transpose (n, xt, x);
+ transform_transpose(n, xt, x);
  } else {
- FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices");
+ FAISS_THROW_MSG(
+ "reverse transform not implemented for non-orthonormal matrices");
  }
  }

-
- void LinearTransform::print_if_verbose (
- const char*name, const std::vector<double> &mat,
- int n, int d) const
- {
- if (!verbose) return;
+ void LinearTransform::print_if_verbose(
+ const char* name,
+ const std::vector<double>& mat,
+ int n,
+ int d) const {
+ if (!verbose)
+ return;
  printf("matrix %s: %d*%d [\n", name, n, d);
- FAISS_THROW_IF_NOT (mat.size() >= n * d);
+ FAISS_THROW_IF_NOT(mat.size() >= n * d);
  for (int i = 0; i < n; i++) {
  for (int j = 0; j < d; j++) {
  printf("%10.5g ", mat[i * d + j]);
@@ -244,24 +312,22 @@ void LinearTransform::print_if_verbose (
  * RandomRotationMatrix
  *********************************************/

- void RandomRotationMatrix::init (int seed)
- {
-
- if(d_out <= d_in) {
- A.resize (d_out * d_in);
- float *q = A.data();
+ void RandomRotationMatrix::init(int seed) {
+ if (d_out <= d_in) {
+ A.resize(d_out * d_in);
+ float* q = A.data();
  float_randn(q, d_out * d_in, seed);
  matrix_qr(d_in, d_out, q);
  } else {
  // use tight-frame transformation
- A.resize (d_out * d_out);
- float *q = A.data();
+ A.resize(d_out * d_out);
+ float* q = A.data();
  float_randn(q, d_out * d_out, seed);
  matrix_qr(d_out, d_out, q);
  // remove columns
  int i, j;
  for (i = 0; i < d_out; i++) {
- for(j = 0; j < d_in; j++) {
+ for (j = 0; j < d_in; j++) {
  q[i * d_in + j] = q[i * d_out + j];
  }
  }
@@ -271,247 +337,280 @@ void RandomRotationMatrix::init (int seed)
  is_trained = true;
  }

- void RandomRotationMatrix::train (Index::idx_t /*n*/, const float * /*x*/)
- {
+ void RandomRotationMatrix::train(Index::idx_t /*n*/, const float* /*x*/) {
  // initialize with some arbitrary seed
- init (12345);
+ init(12345);
  }

-
  /*********************************************
  * PCAMatrix
  *********************************************/

- PCAMatrix::PCAMatrix (int d_in, int d_out,
- float eigen_power, bool random_rotation):
- LinearTransform(d_in, d_out, true),
- eigen_power(eigen_power), random_rotation(random_rotation)
- {
+ PCAMatrix::PCAMatrix(
+ int d_in,
+ int d_out,
+ float eigen_power,
+ bool random_rotation)
+ : LinearTransform(d_in, d_out, true),
+ eigen_power(eigen_power),
+ random_rotation(random_rotation) {
  is_trained = false;
  max_points_per_d = 1000;
  balanced_bins = 0;
  }

-
  namespace {

  /// Compute the eigenvalue decomposition of symmetric matrix cov,
  /// dimensions d_in-by-d_in. Output eigenvectors in cov.

- void eig(size_t d_in, double *cov, double *eigenvalues, int verbose)
- {
+ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
  { // compute eigenvalues and vectors
  FINTEGER info = 0, lwork = -1, di = d_in;
  double workq;

- dsyev_ ("Vectors as well", "Upper",
- &di, cov, &di, eigenvalues, &workq, &lwork, &info);
+ dsyev_("Vectors as well",
+ "Upper",
+ &di,
+ cov,
+ &di,
+ eigenvalues,
+ &workq,
+ &lwork,
+ &info);
  lwork = FINTEGER(workq);
- double *work = new double[lwork];
+ double* work = new double[lwork];

- dsyev_ ("Vectors as well", "Upper",
- &di, cov, &di, eigenvalues, work, &lwork, &info);
+ dsyev_("Vectors as well",
+ "Upper",
+ &di,
+ cov,
+ &di,
+ eigenvalues,
+ work,
+ &lwork,
+ &info);

- delete [] work;
+ delete[] work;

  if (info != 0) {
- fprintf (stderr, "WARN ssyev info returns %d, "
- "a very bad PCA matrix is learnt\n",
- int(info));
+ fprintf(stderr,
+ "WARN ssyev info returns %d, "
+ "a very bad PCA matrix is learnt\n",
+ int(info));
  // do not throw exception, as the matrix could still be useful
  }

-
- if(verbose && d_in <= 10) {
+ if (verbose && d_in <= 10) {
  printf("info=%ld new eigvals=[", long(info));
- for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]);
+ for (int j = 0; j < d_in; j++)
+ printf("%g ", eigenvalues[j]);
  printf("]\n");

- double *ci = cov;
+ double* ci = cov;
  printf("eigenvecs=\n");
- for(int i = 0; i < d_in; i++) {
- for(int j = 0; j < d_in; j++)
+ for (int i = 0; i < d_in; i++) {
+ for (int j = 0; j < d_in; j++)
  printf("%10.4g ", *ci++);
  printf("\n");
  }
  }
-
  }

  // revert order of eigenvectors & values

- for(int i = 0; i < d_in / 2; i++) {
-
+ for (int i = 0; i < d_in / 2; i++) {
  std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
- double *v1 = cov + i * d_in;
- double *v2 = cov + (d_in - 1 - i) * d_in;
- for(int j = 0; j < d_in; j++)
+ double* v1 = cov + i * d_in;
+ double* v2 = cov + (d_in - 1 - i) * d_in;
+ for (int j = 0; j < d_in; j++)
  std::swap(v1[j], v2[j]);
  }
-
  }

+ } // namespace

- }
-
- void PCAMatrix::train (Index::idx_t n, const float *x)
- {
- const float * x_in = x;
+ void PCAMatrix::train(Index::idx_t n, const float* x) {
+ const float* x_in = x;

- x = fvecs_maybe_subsample (d_in, (size_t*)&n,
- max_points_per_d * d_in, x, verbose);
+ x = fvecs_maybe_subsample(
+ d_in, (size_t*)&n, max_points_per_d * d_in, x, verbose);

- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);

  // compute mean
- mean.clear(); mean.resize(d_in, 0.0);
+ mean.clear();
+ mean.resize(d_in, 0.0);
  if (have_bias) { // we may want to skip the bias
- const float *xi = x;
+ const float* xi = x;
  for (int i = 0; i < n; i++) {
- for(int j = 0; j < d_in; j++)
+ for (int j = 0; j < d_in; j++)
  mean[j] += *xi++;
  }
- for(int j = 0; j < d_in; j++)
+ for (int j = 0; j < d_in; j++)
  mean[j] /= n;
  }
- if(verbose) {
+ if (verbose) {
  printf("mean=[");
- for(int j = 0; j < d_in; j++) printf("%g ", mean[j]);
+ for (int j = 0; j < d_in; j++)
+ printf("%g ", mean[j]);
  printf("]\n");
  }

- if(n >= d_in) {
+ if (n >= d_in) {
  // compute covariance matrix, store it in PCA matrix
  PCAMat.resize(d_in * d_in);
- float * cov = PCAMat.data();
+ float* cov = PCAMat.data();
  { // initialize with mean * mean^T term
- float *ci = cov;
- for(int i = 0; i < d_in; i++) {
- for(int j = 0; j < d_in; j++)
- *ci++ = - n * mean[i] * mean[j];
+ float* ci = cov;
+ for (int i = 0; i < d_in; i++) {
+ for (int j = 0; j < d_in; j++)
+ *ci++ = -n * mean[i] * mean[j];
  }
  }
  {
  FINTEGER di = d_in, ni = n;
  float one = 1.0;
- ssyrk_ ("Up", "Non transposed",
- &di, &ni, &one, (float*)x, &di, &one, cov, &di);
-
+ ssyrk_("Up",
+ "Non transposed",
+ &di,
+ &ni,
+ &one,
+ (float*)x,
+ &di,
+ &one,
+ cov,
+ &di);
  }
- if(verbose && d_in <= 10) {
- float *ci = cov;
+ if (verbose && d_in <= 10) {
+ float* ci = cov;
  printf("cov=\n");
- for(int i = 0; i < d_in; i++) {
- for(int j = 0; j < d_in; j++)
+ for (int i = 0; i < d_in; i++) {
+ for (int j = 0; j < d_in; j++)
  printf("%10g ", *ci++);
  printf("\n");
  }
  }

- std::vector<double> covd (d_in * d_in);
- for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i];
+ std::vector<double> covd(d_in * d_in);
+ for (size_t i = 0; i < d_in * d_in; i++)
+ covd[i] = cov[i];

- std::vector<double> eigenvaluesd (d_in);
+ std::vector<double> eigenvaluesd(d_in);

- eig (d_in, covd.data (), eigenvaluesd.data (), verbose);
+ eig(d_in, covd.data(), eigenvaluesd.data(), verbose);

- for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i];
- eigenvalues.resize (d_in);
+ for (size_t i = 0; i < d_in * d_in; i++)
+ PCAMat[i] = covd[i];
+ eigenvalues.resize(d_in);

  for (size_t i = 0; i < d_in; i++)
- eigenvalues [i] = eigenvaluesd [i];
-
+ eigenvalues[i] = eigenvaluesd[i];

  } else {
-
- std::vector<float> xc (n * d_in);
+ std::vector<float> xc(n * d_in);

  for (size_t i = 0; i < n; i++)
- for(size_t j = 0; j < d_in; j++)
- xc [i * d_in + j] = x [i * d_in + j] - mean[j];
+ for (size_t j = 0; j < d_in; j++)
+ xc[i * d_in + j] = x[i * d_in + j] - mean[j];

  // compute Gram matrix
- std::vector<float> gram (n * n);
+ std::vector<float> gram(n * n);
  {
  FINTEGER di = d_in, ni = n;
  float one = 1.0, zero = 0.0;
- ssyrk_ ("Up", "Transposed",
- &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni);
+ ssyrk_("Up",
+ "Transposed",
+ &ni,
+ &di,
+ &one,
+ xc.data(),
+ &di,
+ &zero,
+ gram.data(),
+ &ni);
  }

- if(verbose && d_in <= 10) {
- float *ci = gram.data();
+ if (verbose && d_in <= 10) {
+ float* ci = gram.data();
  printf("gram=\n");
- for(int i = 0; i < n; i++) {
- for(int j = 0; j < n; j++)
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < n; j++)
  printf("%10g ", *ci++);
  printf("\n");
  }
  }

- std::vector<double> gramd (n * n);
+ std::vector<double> gramd(n * n);
  for (size_t i = 0; i < n * n; i++)
- gramd [i] = gram [i];
+ gramd[i] = gram[i];

- std::vector<double> eigenvaluesd (n);
+ std::vector<double> eigenvaluesd(n);

  // eig will fill in only the n first eigenvals

- eig (n, gramd.data (), eigenvaluesd.data (), verbose);
+ eig(n, gramd.data(), eigenvaluesd.data(), verbose);

  PCAMat.resize(d_in * n);

  for (size_t i = 0; i < n * n; i++)
- gram [i] = gramd [i];
+ gram[i] = gramd[i];

- eigenvalues.resize (d_in);
+ eigenvalues.resize(d_in);
  // fill in only the n first ones
  for (size_t i = 0; i < n; i++)
- eigenvalues [i] = eigenvaluesd [i];
+ eigenvalues[i] = eigenvaluesd[i];

  { // compute PCAMat = x' * v
  FINTEGER di = d_in, ni = n;
  float one = 1.0;

- sgemm_ ("Non", "Non Trans",
- &di, &ni, &ni,
- &one, xc.data(), &di, gram.data(), &ni,
- &one, PCAMat.data(), &di);
+ sgemm_("Non",
+ "Non Trans",
+ &di,
+ &ni,
+ &ni,
+ &one,
+ xc.data(),
+ &di,
+ gram.data(),
+ &ni,
+ &one,
+ PCAMat.data(),
+ &di);
  }

- if(verbose && d_in <= 10) {
- float *ci = PCAMat.data();
+ if (verbose && d_in <= 10) {
+ float* ci = PCAMat.data();
  printf("PCAMat=\n");
- for(int i = 0; i < n; i++) {
- for(int j = 0; j < d_in; j++)
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < d_in; j++)
  printf("%10g ", *ci++);
  printf("\n");
  }
  }
- fvec_renorm_L2 (d_in, n, PCAMat.data());
-
+ fvec_renorm_L2(d_in, n, PCAMat.data());
  }

  prepare_Ab();
  is_trained = true;
  }

- void PCAMatrix::copy_from (const PCAMatrix & other)
- {
- FAISS_THROW_IF_NOT (other.is_trained);
+ void PCAMatrix::copy_from(const PCAMatrix& other) {
+ FAISS_THROW_IF_NOT(other.is_trained);
  mean = other.mean;
  eigenvalues = other.eigenvalues;
  PCAMat = other.PCAMat;
- prepare_Ab ();
+ prepare_Ab();
  is_trained = true;
  }

- void PCAMatrix::prepare_Ab ()
- {
- FAISS_THROW_IF_NOT_FMT (
+ void PCAMatrix::prepare_Ab() {
+ FAISS_THROW_IF_NOT_FMT(
  d_out * d_in <= PCAMat.size(),
  "PCA matrix cannot output %d dimensions from %d ",
- d_out, d_in);
+ d_out,
+ d_in);

  if (!random_rotation) {
  A = PCAMat;
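The hunk above reformats RandomRotationMatrix, the eig() helper, and PCAMatrix::train without behavioral changes. For context on how a trained PCAMatrix is driven, a minimal sketch using the VectorTransform API declared in this file; the training data here is synthetic:

```cpp
#include <faiss/VectorTransform.h>
#include <random>
#include <vector>

int main() {
    int d_in = 64, d_out = 16;
    size_t n = 1000;

    // synthetic training set; any row-major n * d_in float buffer works
    std::vector<float> x(n * d_in);
    std::mt19937 rng(123);
    std::normal_distribution<float> dist;
    for (auto& v : x) v = dist(rng);

    faiss::PCAMatrix pca(d_in, d_out);  // eigen_power = 0, no random rotation
    pca.train(n, x.data());             // covariance or Gram path, as above

    // apply() news the output buffer (see VectorTransform::apply above);
    // the caller is responsible for freeing it
    float* xt = pca.apply(n, x.data());
    delete[] xt;
    return 0;
}
```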
@@ -519,23 +618,23 @@ void PCAMatrix::prepare_Ab ()

  // first scale the components
  if (eigen_power != 0) {
- float *ai = A.data();
+ float* ai = A.data();
  for (int i = 0; i < d_out; i++) {
  float factor = pow(eigenvalues[i], eigen_power);
- for(int j = 0; j < d_in; j++)
+ for (int j = 0; j < d_in; j++)
  *ai++ *= factor;
  }
  }

  if (balanced_bins != 0) {
- FAISS_THROW_IF_NOT (d_out % balanced_bins == 0);
+ FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
  int dsub = d_out / balanced_bins;
- std::vector <float> Ain;
+ std::vector<float> Ain;
  std::swap(A, Ain);
  A.resize(d_out * d_in);

- std::vector <float> accu(balanced_bins);
- std::vector <int> counter(balanced_bins);
+ std::vector<float> accu(balanced_bins);
+ std::vector<int> counter(balanced_bins);

  // greedy assignment
  for (int i = 0; i < d_out; i++) {
@@ -550,9 +649,8 @@ void PCAMatrix::prepare_Ab ()
  }
  int row_dst = best_j * dsub + counter[best_j];
  accu[best_j] += eigenvalues[i];
- counter[best_j] ++;
- memcpy (&A[row_dst * d_in], &Ain[i * d_in],
- d_in * sizeof (A[0]));
+ counter[best_j]++;
+ memcpy(&A[row_dst * d_in], &Ain[i * d_in], d_in * sizeof(A[0]));
  }

  if (verbose) {
@@ -563,11 +661,11 @@ void PCAMatrix::prepare_Ab ()
  }
  }

-
  } else {
- FAISS_THROW_IF_NOT_MSG (balanced_bins == 0,
- "both balancing bins and applying a random rotation "
- "does not make sense");
+ FAISS_THROW_IF_NOT_MSG(
+ balanced_bins == 0,
+ "both balancing bins and applying a random rotation "
+ "does not make sense");
  RandomRotationMatrix rr(d_out, d_out);

  rr.init(5);
@@ -576,8 +674,8 @@ void PCAMatrix::prepare_Ab ()
  if (eigen_power != 0) {
  for (int i = 0; i < d_out; i++) {
  float factor = pow(eigenvalues[i], eigen_power);
- for(int j = 0; j < d_out; j++)
- rr.A[j * d_out + i] *= factor;
+ for (int j = 0; j < d_out; j++)
+ rr.A[j * d_out + i] *= factor;
  }
  }

@@ -586,15 +684,24 @@ void PCAMatrix::prepare_Ab ()
  FINTEGER dii = d_in, doo = d_out;
  float one = 1.0, zero = 0.0;

- sgemm_ ("Not", "Not", &dii, &doo, &doo,
- &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero,
- A.data(), &dii);
-
+ sgemm_("Not",
+ "Not",
+ &dii,
+ &doo,
+ &doo,
+ &one,
+ PCAMat.data(),
+ &dii,
+ rr.A.data(),
+ &doo,
+ &zero,
+ A.data(),
+ &dii);
  }
-
  }

- b.clear(); b.resize(d_out);
+ b.clear();
+ b.resize(d_out);

  for (int i = 0; i < d_out; i++) {
  float accu = 0;
@@ -604,57 +711,61 @@ void PCAMatrix::prepare_Ab ()
  }

  is_orthonormal = eigen_power == 0;
-
  }

  /*********************************************
  * ITQMatrix
  *********************************************/

- ITQMatrix::ITQMatrix (int d):
- LinearTransform(d, d, false),
- max_iter (50),
- seed (123)
- {
- }
-
+ ITQMatrix::ITQMatrix(int d)
+ : LinearTransform(d, d, false), max_iter(50), seed(123) {}

  /** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
- void ITQMatrix::train (Index::idx_t n, const float* xf)
- {
+ void ITQMatrix::train(Index::idx_t n, const float* xf) {
  size_t d = d_in;
- std::vector<double> rotation (d * d);
+ std::vector<double> rotation(d * d);

  if (init_rotation.size() == d * d) {
- memcpy (rotation.data(), init_rotation.data(),
- d * d * sizeof(rotation[0]));
+ memcpy(rotation.data(),
+ init_rotation.data(),
+ d * d * sizeof(rotation[0]));
  } else {
- RandomRotationMatrix rrot (d, d);
- rrot.init (seed);
+ RandomRotationMatrix rrot(d, d);
+ rrot.init(seed);
  for (size_t i = 0; i < d * d; i++) {
  rotation[i] = rrot.A[i];
  }
  }

- std::vector<double> x (n * d);
+ std::vector<double> x(n * d);

  for (size_t i = 0; i < n * d; i++) {
  x[i] = xf[i];
  }

- std::vector<double> rotated_x (n * d), cov_mat (d * d);
- std::vector<double> u (d * d), vt (d * d), singvals (d);
+ std::vector<double> rotated_x(n * d), cov_mat(d * d);
+ std::vector<double> u(d * d), vt(d * d), singvals(d);

  for (int i = 0; i < max_iter; i++) {
- print_if_verbose ("rotation", rotation, d, d);
+ print_if_verbose("rotation", rotation, d, d);
  { // rotated_data = np.dot(training_data, rotation)
  FINTEGER di = d, ni = n;
  double one = 1, zero = 0;
- dgemm_ ("N", "N", &di, &ni, &di,
- &one, rotation.data(), &di, x.data(), &di,
- &zero, rotated_x.data(), &di);
+ dgemm_("N",
+ "N",
+ &di,
+ &ni,
+ &di,
+ &one,
+ rotation.data(),
+ &di,
+ x.data(),
+ &di,
+ &zero,
+ rotated_x.data(),
+ &di);
  }
- print_if_verbose ("rotated_x", rotated_x, n, d);
+ print_if_verbose("rotated_x", rotated_x, n, d);
  // binarize
  for (size_t j = 0; j < n * d; j++) {
  rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
@@ -663,88 +774,119 @@ void ITQMatrix::train (Index::idx_t n, const float* xf)
  { // rotated_data = np.dot(training_data, rotation)
  FINTEGER di = d, ni = n;
  double one = 1, zero = 0;
- dgemm_ ("N", "T", &di, &di, &ni,
- &one, rotated_x.data(), &di, x.data(), &di,
- &zero, cov_mat.data(), &di);
+ dgemm_("N",
+ "T",
+ &di,
+ &di,
+ &ni,
+ &one,
+ rotated_x.data(),
+ &di,
+ x.data(),
+ &di,
+ &zero,
+ cov_mat.data(),
+ &di);
  }
- print_if_verbose ("cov_mat", cov_mat, d, d);
+ print_if_verbose("cov_mat", cov_mat, d, d);
  // SVD
  {
-
  FINTEGER di = d;
  FINTEGER lwork = -1, info;
  double lwork1;

  // workspace query
- dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
- singvals.data(), u.data(), &di,
- vt.data(), &di,
- &lwork1, &lwork, &info);
-
- FAISS_THROW_IF_NOT (info == 0);
- lwork = size_t (lwork1);
- std::vector<double> work (lwork);
- dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
- singvals.data(), u.data(), &di,
- vt.data(), &di,
- work.data(), &lwork, &info);
- FAISS_THROW_IF_NOT_FMT (info == 0, "sgesvd returned info=%d", info);
-
+ dgesvd_("A",
+ "A",
+ &di,
+ &di,
+ cov_mat.data(),
+ &di,
+ singvals.data(),
+ u.data(),
+ &di,
+ vt.data(),
+ &di,
+ &lwork1,
+ &lwork,
+ &info);
+
+ FAISS_THROW_IF_NOT(info == 0);
+ lwork = size_t(lwork1);
+ std::vector<double> work(lwork);
+ dgesvd_("A",
+ "A",
+ &di,
+ &di,
+ cov_mat.data(),
+ &di,
+ singvals.data(),
+ u.data(),
+ &di,
+ vt.data(),
+ &di,
+ work.data(),
+ &lwork,
+ &info);
+ FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
  }
- print_if_verbose ("u", u, d, d);
- print_if_verbose ("vt", vt, d, d);
+ print_if_verbose("u", u, d, d);
+ print_if_verbose("vt", vt, d, d);
  // update rotation
  {
  FINTEGER di = d;
  double one = 1, zero = 0;
- dgemm_ ("N", "T", &di, &di, &di,
- &one, u.data(), &di, vt.data(), &di,
- &zero, rotation.data(), &di);
+ dgemm_("N",
+ "T",
+ &di,
+ &di,
+ &di,
+ &one,
+ u.data(),
+ &di,
+ vt.data(),
+ &di,
+ &zero,
+ rotation.data(),
+ &di);
  }
- print_if_verbose ("final rot", rotation, d, d);
-
+ print_if_verbose("final rot", rotation, d, d);
  }
- A.resize (d * d);
+ A.resize(d * d);
  for (size_t i = 0; i < d; i++) {
  for (size_t j = 0; j < d; j++) {
  A[i + d * j] = rotation[j + d * i];
  }
  }
  is_trained = true;
-
  }

- ITQTransform::ITQTransform (int d_in, int d_out, bool do_pca):
- VectorTransform (d_in, d_out),
- do_pca (do_pca),
- itq (d_out),
- pca_then_itq (d_in, d_out, false)
- {
+ ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca)
+ : VectorTransform(d_in, d_out),
+ do_pca(do_pca),
+ itq(d_out),
+ pca_then_itq(d_in, d_out, false) {
  if (!do_pca) {
- FAISS_THROW_IF_NOT (d_in == d_out);
+ FAISS_THROW_IF_NOT(d_in == d_out);
  }
  max_train_per_dim = 10;
  is_trained = false;
  }

+ void ITQTransform::train(idx_t n, const float* x) {
+ FAISS_THROW_IF_NOT(!is_trained);

-
-
- void ITQTransform::train (idx_t n, const float *x)
- {
- FAISS_THROW_IF_NOT (!is_trained);
-
- const float * x_in = x;
+ const float* x_in = x;
  size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
- x = fvecs_maybe_subsample (d_in, (size_t*)&n, max_train_points, x);
+ x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x);

- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);

- std::unique_ptr<float []> x_norm(new float[n * d_in]);
+ std::unique_ptr<float[]> x_norm(new float[n * d_in]);
  { // normalize
  int d = d_in;

- mean.resize (d, 0);
+ mean.resize(d, 0);
  for (idx_t i = 0; i < n; i++) {
  for (idx_t j = 0; j < d; j++) {
  mean[j] += x[i * d + j];
@@ -755,38 +897,47 @@ void ITQTransform::train (idx_t n, const float *x)
  }
  for (idx_t i = 0; i < n; i++) {
  for (idx_t j = 0; j < d; j++) {
- x_norm[i * d + j] = x[i * d + j] - mean[j];
+ x_norm[i * d + j] = x[i * d + j] - mean[j];
  }
  }
- fvec_renorm_L2 (d_in, n, x_norm.get());
+ fvec_renorm_L2(d_in, n, x_norm.get());
  }

  // train PCA

- PCAMatrix pca (d_in, d_out);
- float *x_pca;
- std::unique_ptr<float []> x_pca_del;
+ PCAMatrix pca(d_in, d_out);
+ float* x_pca;
+ std::unique_ptr<float[]> x_pca_del;
  if (do_pca) {
- pca.have_bias = false; // for consistency with reference implem
- pca.train (n, x_norm.get());
- x_pca = pca.apply (n, x_norm.get());
+ pca.have_bias = false; // for consistency with reference implem
+ pca.train(n, x_norm.get());
+ x_pca = pca.apply(n, x_norm.get());
  x_pca_del.reset(x_pca);
  } else {
  x_pca = x_norm.get();
  }

  // train ITQ
- itq.train (n, x_pca);
+ itq.train(n, x_pca);

  // merge PCA and ITQ
  if (do_pca) {
  FINTEGER di = d_out, dini = d_in;
  float one = 1, zero = 0;
  pca_then_itq.A.resize(d_in * d_out);
- sgemm_ ("N", "N", &dini, &di, &di,
- &one, pca.A.data(), &dini,
- itq.A.data(), &di,
- &zero, pca_then_itq.A.data(), &dini);
+ sgemm_("N",
+ "N",
+ &dini,
+ &di,
+ &di,
+ &one,
+ pca.A.data(),
+ &dini,
+ itq.A.data(),
+ &di,
+ &zero,
+ pca_then_itq.A.data(),
+ &dini);
  } else {
  pca_then_itq.A = itq.A;
  }
@@ -794,12 +945,11 @@ void ITQTransform::train (idx_t n, const float *x)
  is_trained = true;
  }

- void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
- float * xt) const
- {
+ void ITQTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
+ const {
  FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");

- std::unique_ptr<float []> x_norm(new float[n * d_in]);
+ std::unique_ptr<float[]> x_norm(new float[n * d_in]);
  { // normalize
  int d = d_in;
  for (idx_t i = 0; i < n; i++) {
@@ -809,41 +959,36 @@ void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
  }
  // this is not really useful if we are going to binarize right
  // afterwards but OK
- fvec_renorm_L2 (d_in, n, x_norm.get());
+ fvec_renorm_L2(d_in, n, x_norm.get());
  }

- pca_then_itq.apply_noalloc (n, x_norm.get(), xt);
+ pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
  }

  /*********************************************
  * OPQMatrix
  *********************************************/

-
- OPQMatrix::OPQMatrix (int d, int M, int d2):
- LinearTransform (d, d2 == -1 ? d : d2, false), M(M),
- niter (50),
- niter_pq (4), niter_pq_0 (40),
- verbose(false),
- pq(nullptr)
- {
+ OPQMatrix::OPQMatrix(int d, int M, int d2)
+ : LinearTransform(d, d2 == -1 ? d : d2, false),
+ M(M),
+ niter(50),
+ niter_pq(4),
+ niter_pq_0(40),
+ verbose(false),
+ pq(nullptr) {
  is_trained = false;
  // OPQ is quite expensive to train, so set this right.
  max_train_points = 256 * 256;
  pq = nullptr;
  }

+ void OPQMatrix::train(Index::idx_t n, const float* x) {
+ const float* x_in = x;

+ x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose);

- void OPQMatrix::train (Index::idx_t n, const float *x)
- {
-
- const float * x_in = x;
-
- x = fvecs_maybe_subsample (d_in, (size_t*)&n,
- max_train_points, x, verbose);
-
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);

  // To support d_out > d_in, we pad input vectors with 0s to d_out
  size_t d = d_out <= d_in ? d_in : d_out;
@@ -867,22 +1012,26 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  #endif

  if (verbose) {
- printf ("OPQMatrix::train: training an OPQ rotation matrix "
- "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
- M, n, d_in, d_out);
+ printf("OPQMatrix::train: training an OPQ rotation matrix "
+ "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
+ M,
+ n,
+ d_in,
+ d_out);
  }

- std::vector<float> xtrain (n * d);
+ std::vector<float> xtrain(n * d);
  // center x
  {
- std::vector<float> sum (d);
- const float *xi = x;
+ std::vector<float> sum(d);
+ const float* xi = x;
  for (size_t i = 0; i < n; i++) {
  for (int j = 0; j < d_in; j++)
- sum [j] += *xi++;
+ sum[j] += *xi++;
  }
- for (int i = 0; i < d; i++) sum[i] /= n;
- float *yi = xtrain.data();
+ for (int i = 0; i < d; i++)
+ sum[i] /= n;
+ float* yi = xtrain.data();
  xi = x;
  for (size_t i = 0; i < n; i++) {
  for (int j = 0; j < d_in; j++)
@@ -890,71 +1039,80 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  yi += d - d_in;
  }
  }
- float *rotation;
+ float* rotation;
 
- if (A.size () == 0) {
- A.resize (d * d);
+ if (A.size() == 0) {
+ A.resize(d * d);
  rotation = A.data();
  if (verbose)
  printf(" OPQMatrix::train: making random %zd*%zd rotation\n",
- d, d);
- float_randn (rotation, d * d, 1234);
- matrix_qr (d, d, rotation);
+ d,
+ d);
+ float_randn(rotation, d * d, 1234);
+ matrix_qr(d, d, rotation);
  // we use only the d * d2 upper part of the matrix
- A.resize (d * d2);
+ A.resize(d * d2);
  } else {
- FAISS_THROW_IF_NOT (A.size() == d * d2);
+ FAISS_THROW_IF_NOT(A.size() == d * d2);
  rotation = A.data();
  }
 
- std::vector<float>
- xproj (d2 * n), pq_recons (d2 * n), xxr (d * n),
- tmp(d * d * 4);
-
+ std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * n),
+ tmp(d * d * 4);
 
- ProductQuantizer pq_default (d2, M, 8);
- ProductQuantizer &pq_regular = pq ? *pq : pq_default;
- std::vector<uint8_t> codes (pq_regular.code_size * n);
+ ProductQuantizer pq_default(d2, M, 8);
+ ProductQuantizer& pq_regular = pq ? *pq : pq_default;
+ std::vector<uint8_t> codes(pq_regular.code_size * n);
 
  double t0 = getmillisecs();
  for (int iter = 0; iter < niter; iter++) {
-
  { // torch.mm(xtrain, rotation:t())
  FINTEGER di = d, d2i = d2, ni = n;
  float zero = 0, one = 1;
- sgemm_ ("Transposed", "Not transposed",
- &d2i, &ni, &di,
- &one, rotation, &di,
- xtrain.data(), &di,
- &zero, xproj.data(), &d2i);
+ sgemm_("Transposed",
+ "Not transposed",
+ &d2i,
+ &ni,
+ &di,
+ &one,
+ rotation,
+ &di,
+ xtrain.data(),
+ &di,
+ &zero,
+ xproj.data(),
+ &d2i);
  }
 
  pq_regular.cp.max_points_per_centroid = 1000;
  pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
  pq_regular.verbose = verbose;
- pq_regular.train (n, xproj.data());
+ pq_regular.train(n, xproj.data());
 
  if (verbose) {
  printf(" encode / decode\n");
  }
  if (pq_regular.assign_index) {
- pq_regular.compute_codes_with_assign_index
- (xproj.data(), codes.data(), n);
+ pq_regular.compute_codes_with_assign_index(
+ xproj.data(), codes.data(), n);
  } else {
- pq_regular.compute_codes (xproj.data(), codes.data(), n);
+ pq_regular.compute_codes(xproj.data(), codes.data(), n);
  }
- pq_regular.decode (codes.data(), pq_recons.data(), n);
+ pq_regular.decode(codes.data(), pq_recons.data(), n);
 
- float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n;
+ float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;
 
  if (verbose)
- printf (" Iteration %d (%d PQ iterations):"
- "%.3f s, obj=%g\n", iter, pq_regular.cp.niter,
- (getmillisecs () - t0) / 1000.0, pq_err);
+ printf(" Iteration %d (%d PQ iterations):"
+ "%.3f s, obj=%g\n",
+ iter,
+ pq_regular.cp.niter,
+ (getmillisecs() - t0) / 1000.0,
+ pq_err);
 
  {
- float *u = tmp.data(), *vt = &tmp [d * d];
- float *sing_val = &tmp [2 * d * d];
+ float *u = tmp.data(), *vt = &tmp[d * d];
+ float* sing_val = &tmp[2 * d * d];
  FINTEGER di = d, d2i = d2, ni = n;
  float one = 1, zero = 0;
 
@@ -962,36 +1120,69 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  printf(" X * recons\n");
  }
  // torch.mm(xtrain:t(), pq_recons)
- sgemm_ ("Not", "Transposed",
- &d2i, &di, &ni,
- &one, pq_recons.data(), &d2i,
- xtrain.data(), &di,
- &zero, xxr.data(), &d2i);
-
+ sgemm_("Not",
+ "Transposed",
+ &d2i,
+ &di,
+ &ni,
+ &one,
+ pq_recons.data(),
+ &d2i,
+ xtrain.data(),
+ &di,
+ &zero,
+ xxr.data(),
+ &d2i);
 
  FINTEGER lwork = -1, info = -1;
  float worksz;
  // workspace query
- sgesvd_ ("All", "All",
- &d2i, &di, xxr.data(), &d2i,
- sing_val,
- vt, &d2i, u, &di,
- &worksz, &lwork, &info);
+ sgesvd_("All",
+ "All",
+ &d2i,
+ &di,
+ xxr.data(),
+ &d2i,
+ sing_val,
+ vt,
+ &d2i,
+ u,
+ &di,
+ &worksz,
+ &lwork,
+ &info);
 
  lwork = int(worksz);
- std::vector<float> work (lwork);
+ std::vector<float> work(lwork);
  // u and vt swapped
- sgesvd_ ("All", "All",
- &d2i, &di, xxr.data(), &d2i,
- sing_val,
- vt, &d2i, u, &di,
- work.data(), &lwork, &info);
-
- sgemm_ ("Transposed", "Transposed",
- &di, &d2i, &d2i,
- &one, u, &di, vt, &d2i,
- &zero, rotation, &di);
-
+ sgesvd_("All",
+ "All",
+ &d2i,
+ &di,
+ xxr.data(),
+ &d2i,
+ sing_val,
+ vt,
+ &d2i,
+ u,
+ &di,
+ work.data(),
+ &lwork,
+ &info);
+
+ sgemm_("Transposed",
+ "Transposed",
+ &di,
+ &d2i,
+ &d2i,
+ &one,
+ u,
+ &di,
+ vt,
+ &d2i,
+ &zero,
+ rotation,
+ &di);
  }
  pq_regular.train_type = ProductQuantizer::Train_hot_start;
  }
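Note: the loop that just closed is the heart of OPQ training: project the data with the current rotation (sgemm), fit and apply the ProductQuantizer, then re-fit the rotation from an SVD of the data/reconstruction cross-product (an orthogonal Procrustes step). Callers never see this machinery; a hedged usage sketch of the public API, with all sizes chosen arbitrarily:

    #include <faiss/VectorTransform.h>
    #include <random>
    #include <vector>

    int main() {
        const int d = 64, M = 8; // illustrative: 64-D data, 8 PQ sub-quantizers
        const size_t n = 10000;
        std::vector<float> x(n * d);
        std::mt19937 rng(123);
        std::normal_distribution<float> dist;
        for (auto& v : x) v = dist(rng);

        faiss::OPQMatrix opq(d, M); // d2 == -1 keeps the output dimension at d
        opq.train(n, x.data());     // subsamples internally via max_train_points

        std::vector<float> xr(n * d);
        opq.apply_noalloc(n, x.data(), xr.data()); // rotated vectors, ready for a PQ index
        return 0;
    }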
@@ -999,59 +1190,52 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
  // revert A matrix
  if (d > d_in) {
  for (long i = 0; i < d_out; i++)
- memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
- A.resize (d_in * d_out);
+ memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
+ A.resize(d_in * d_out);
  }
 
  is_trained = true;
  is_orthonormal = true;
  }
 
-
  /*********************************************
  * NormalizationTransform
  *********************************************/
 
- NormalizationTransform::NormalizationTransform (int d, float norm):
- VectorTransform (d, d), norm (norm)
- {
- }
+ NormalizationTransform::NormalizationTransform(int d, float norm)
+ : VectorTransform(d, d), norm(norm) {}
 
- NormalizationTransform::NormalizationTransform ():
- VectorTransform (-1, -1), norm (-1)
- {
- }
+ NormalizationTransform::NormalizationTransform()
+ : VectorTransform(-1, -1), norm(-1) {}
 
- void NormalizationTransform::apply_noalloc
- (idx_t n, const float* x, float* xt) const
- {
+ void NormalizationTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+ const {
  if (norm == 2.0) {
- memcpy (xt, x, sizeof (x[0]) * n * d_in);
- fvec_renorm_L2 (d_in, n, xt);
+ memcpy(xt, x, sizeof(x[0]) * n * d_in);
+ fvec_renorm_L2(d_in, n, xt);
  } else {
- FAISS_THROW_MSG ("not implemented");
+ FAISS_THROW_MSG("not implemented");
  }
  }
 
- void NormalizationTransform::reverse_transform (idx_t n, const float* xt,
- float* x) const
- {
- memcpy (x, xt, sizeof (xt[0]) * n * d_in);
+ void NormalizationTransform::reverse_transform(
+ idx_t n,
+ const float* xt,
+ float* x) const {
+ memcpy(x, xt, sizeof(xt[0]) * n * d_in);
  }
 
  /*********************************************
  * CenteringTransform
  *********************************************/
 
- CenteringTransform::CenteringTransform (int d):
- VectorTransform (d, d)
- {
+ CenteringTransform::CenteringTransform(int d) : VectorTransform(d, d) {
  is_trained = false;
  }
 
- void CenteringTransform::train(Index::idx_t n, const float *x) {
+ void CenteringTransform::train(Index::idx_t n, const float* x) {
  FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
- mean.resize (d_in, 0);
+ mean.resize(d_in, 0);
  for (idx_t i = 0; i < n; i++) {
  for (size_t j = 0; j < d_in; j++) {
  mean[j] += *x++;
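Note: the context elided between this hunk and the next finishes the mean estimate (upstream faiss divides each accumulated mean[j] by n before the function sets is_trained). A round-trip sketch of CenteringTransform; the helper name, d, and the data are illustrative assumptions:

    #include <faiss/VectorTransform.h>
    #include <vector>

    // Center vectors with CenteringTransform, then undo it.
    void centering_roundtrip(const std::vector<float>& x, size_t n, int d) {
        faiss::CenteringTransform ct(d);
        ct.train(n, x.data());                           // estimate the per-dimension mean
        std::vector<float> xt(n * d), back(n * d);
        ct.apply_noalloc(n, x.data(), xt.data());        // subtract the mean
        ct.reverse_transform(n, xt.data(), back.data()); // add it back: back == x
    }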
@@ -1064,11 +1248,9 @@ void CenteringTransform::train(Index::idx_t n, const float *x) {
  is_trained = true;
  }
 
-
- void CenteringTransform::apply_noalloc
- (idx_t n, const float* x, float* xt) const
- {
- FAISS_THROW_IF_NOT (is_trained);
+ void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+ const {
+ FAISS_THROW_IF_NOT(is_trained);
 
  for (idx_t i = 0; i < n; i++) {
  for (size_t j = 0; j < d_in; j++) {
@@ -1077,64 +1259,58 @@ void CenteringTransform::apply_noalloc
  }
  }
 
- void CenteringTransform::reverse_transform (idx_t n, const float* xt,
- float* x) const
- {
- FAISS_THROW_IF_NOT (is_trained);
+ void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
+ const {
+ FAISS_THROW_IF_NOT(is_trained);
 
  for (idx_t i = 0; i < n; i++) {
  for (size_t j = 0; j < d_in; j++) {
  *x++ = *xt++ + mean[j];
  }
  }
-
  }
 
-
-
-
-
  /*********************************************
  * RemapDimensionsTransform
  *********************************************/
 
-
- RemapDimensionsTransform::RemapDimensionsTransform (
- int d_in, int d_out, const int *map_in):
- VectorTransform (d_in, d_out)
- {
- map.resize (d_out);
+ RemapDimensionsTransform::RemapDimensionsTransform(
+ int d_in,
+ int d_out,
+ const int* map_in)
+ : VectorTransform(d_in, d_out) {
+ map.resize(d_out);
  for (int i = 0; i < d_out; i++) {
  map[i] = map_in[i];
- FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
+ FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
  }
  }
 
- RemapDimensionsTransform::RemapDimensionsTransform (
- int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out)
- {
- map.resize (d_out, -1);
+ RemapDimensionsTransform::RemapDimensionsTransform(
+ int d_in,
+ int d_out,
+ bool uniform)
+ : VectorTransform(d_in, d_out) {
+ map.resize(d_out, -1);
 
  if (uniform) {
  if (d_in < d_out) {
  for (int i = 0; i < d_in; i++) {
- map [i * d_out / d_in] = i;
- }
+ map[i * d_out / d_in] = i;
+ }
  } else {
  for (int i = 0; i < d_out; i++) {
- map [i] = i * d_in / d_out;
+ map[i] = i * d_in / d_out;
  }
  }
  } else {
  for (int i = 0; i < d_in && i < d_out; i++)
- map [i] = i;
+ map[i] = i;
  }
  }
 
-
- void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
- float *xt) const
- {
+ void RemapDimensionsTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+ const {
  for (idx_t i = 0; i < n; i++) {
  for (int j = 0; j < d_out; j++) {
  xt[j] = map[j] < 0 ? 0 : x[map[j]];
@@ -1144,13 +1320,15 @@ void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
  }
  }
 
- void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt,
- float *x) const
- {
- memset (x, 0, sizeof (*x) * n * d_in);
+ void RemapDimensionsTransform::reverse_transform(
+ idx_t n,
+ const float* xt,
+ float* x) const {
+ memset(x, 0, sizeof(*x) * n * d_in);
  for (idx_t i = 0; i < n; i++) {
  for (int j = 0; j < d_out; j++) {
- if (map[j] >= 0) x[map[j]] = xt[j];
+ if (map[j] >= 0)
+ x[map[j]] = xt[j];
  }
  x += d_in;
  xt += d_out;
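Note: reverse_transform here is lossy by design: the forward pass drops any input dimension that no map entry points to, and the memset ensures those dimensions come back as zeros. A small sketch of the uniform-mapping constructor; the 8 -> 4 sizes are illustrative:

    #include <faiss/VectorTransform.h>
    #include <vector>

    int main() {
        // Uniform 8 -> 4 builds map = {0, 2, 4, 6} (i * d_in / d_out).
        faiss::RemapDimensionsTransform remap(8, 4, /*uniform=*/true);
        std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};
        std::vector<float> xt(4), back(8);
        remap.apply_noalloc(1, x.data(), xt.data());        // xt = {0, 2, 4, 6}
        remap.reverse_transform(1, xt.data(), back.data()); // odd dimensions return as 0
        return 0;
    }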