faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,199 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <memory>
13
+ #include <mutex>
14
+ #include <vector>
15
+
16
+ #include <omp.h>
17
+
18
+ #include <faiss/Index.h>
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/utils/Heap.h>
22
+ #include <faiss/utils/random.h>
23
+
24
+ namespace faiss {
25
+
26
+ /** Implementation of the Navigating Spreading-out Graph (NSG)
27
+ * data structure.
28
+ *
29
+ * Fast Approximate Nearest Neighbor Search With The
30
+ * Navigating Spreading-out Graph
31
+ *
32
+ * Cong Fu, Chao Xiang, Changxu Wang, Deng Cai, VLDB 2019
33
+ *
34
+ * This implementation is heavily influenced by the NSG
35
+ * implementation by ZJULearning Group
36
+ * (https://github.com/zjulearning/nsg)
37
+ *
38
+ * The NSG object stores only the neighbor link structure, see
39
+ * IndexNSG.h for the full index object.
40
+ */
41
+
42
+ struct DistanceComputer; // from AuxIndexStructures
43
+ struct Neighbor;
44
+ struct Node;
45
+
46
+ namespace nsg {
47
+
48
+ /***********************************************************
49
+ * Graph structure to store a graph.
50
+ *
51
+ * It is represented by an adjacency matrix `data`, where
52
+ * data[i, j] is the j-th neighbor of node i.
53
+ ***********************************************************/
54
+
55
+ template <class node_t>
56
+ struct Graph {
57
+ node_t* data; ///< the flattened adjacency matrix
58
+ int K; ///< nb of neighbors per node
59
+ int N; ///< total nb of nodes
60
+ bool own_fields; ///< the underlying data owned by itself or not
61
+
62
+ // construct from a known graph
63
+ Graph(node_t* data, int N, int K)
64
+ : data(data), K(K), N(N), own_fields(false) {}
65
+
66
+ // construct an empty graph
67
+ // NOTE: the newly allocated data needs to be destroyed at destruction time
68
+ Graph(int N, int K) : K(K), N(N), own_fields(true) {
69
+ data = new node_t[N * K];
70
+ }
71
+
72
+ // copy constructor
73
+ Graph(const Graph& g) : Graph(g.N, g.K) {
74
+ memcpy(data, g.data, N * K * sizeof(node_t));
75
+ }
76
+
77
+ // release the allocated memory if needed
78
+ ~Graph() {
79
+ if (own_fields) {
80
+ delete[] data;
81
+ }
82
+ }
83
+
84
+ // access the j-th neighbor of node i
85
+ inline node_t at(int i, int j) const {
86
+ return data[i * K + j];
87
+ }
88
+
89
+ // access the j-th neighbor of node i by reference
90
+ inline node_t& at(int i, int j) {
91
+ return data[i * K + j];
92
+ }
93
+ };
94
+
95
+ DistanceComputer* storage_distance_computer(const Index* storage);
96
+
97
+ } // namespace nsg
98
+
99
+ struct NSG {
100
+ /// internal storage of vectors (32 bits: this is expensive)
101
+ using storage_idx_t = int;
102
+
103
+ /// Faiss results are 64-bit
104
+ using idx_t = Index::idx_t;
105
+
106
+ int ntotal; ///< nb of nodes
107
+
108
+ /// construction-time parameters
109
+ int R; ///< nb of neighbors per node
110
+ int L; ///< length of the search path at construction time
111
+ int C; ///< candidate pool size at construction time
112
+
113
+ // search-time parameters
114
+ int search_L; ///< length of the search path
115
+
116
+ int enterpoint; ///< enterpoint
117
+
118
+ std::shared_ptr<nsg::Graph<int>> final_graph; ///< NSG graph structure
119
+
120
+ bool is_built; ///< NSG is built or not
121
+
122
+ RandomGenerator rng; ///< random generator
123
+
124
+ explicit NSG(int R = 32);
125
+
126
+ // build NSG from a KNN graph
127
+ void build(
128
+ Index* storage,
129
+ idx_t n,
130
+ const nsg::Graph<idx_t>& knn_graph,
131
+ bool verbose);
132
+
133
+ // reset the graph
134
+ void reset();
135
+
136
+ // search interface
137
+ void search(
138
+ DistanceComputer& dis,
139
+ int k,
140
+ idx_t* I,
141
+ float* D,
142
+ VisitedTable& vt) const;
143
+
144
+ // Compute the center point
145
+ void init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph);
146
+
147
+ // Search on a built graph.
148
+ // If collect_fullset is true, the visited nodes will be
149
+ // collected in `fullset`.
150
+ template <bool collect_fullset, class index_t>
151
+ void search_on_graph(
152
+ const nsg::Graph<index_t>& graph,
153
+ DistanceComputer& dis,
154
+ VisitedTable& vt,
155
+ int ep,
156
+ int pool_size,
157
+ std::vector<Neighbor>& retset,
158
+ std::vector<Node>& fullset) const;
159
+
160
+ // Add reverse links
161
+ void add_reverse_links(
162
+ int q,
163
+ std::vector<std::mutex>& locks,
164
+ DistanceComputer& dis,
165
+ nsg::Graph<Node>& graph);
166
+
167
+ void sync_prune(
168
+ int q,
169
+ std::vector<Node>& pool,
170
+ DistanceComputer& dis,
171
+ VisitedTable& vt,
172
+ const nsg::Graph<idx_t>& knn_graph,
173
+ nsg::Graph<Node>& graph);
174
+
175
+ void link(
176
+ Index* storage,
177
+ const nsg::Graph<idx_t>& knn_graph,
178
+ nsg::Graph<Node>& graph,
179
+ bool verbose);
180
+
181
+ // make NSG be fully connected
182
+ int tree_grow(Index* storage, std::vector<int>& degrees);
183
+
184
+ // count the size of the connected component
185
+ // using a depth-first search starting from root
186
+ int dfs(VisitedTable& vt, int root, int cnt) const;
187
+
188
+ // attach one unlinked node
189
+ int attach_unlinked(
190
+ Index* storage,
191
+ VisitedTable& vt,
192
+ VisitedTable& vt2,
193
+ std::vector<int>& degrees);
194
+
195
+ // check the integrity of the NSG built
196
+ void check_graph() const;
197
+ };
198
+
199
+ } // namespace faiss
@@ -8,18 +8,21 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/impl/PolysemousTraining.h>
11
+ #include "faiss/impl/FaissAssert.h"
12
+
13
+ #include <omp.h>
14
+ #include <stdint.h>
11
15
 
12
- #include <cstdlib>
13
16
  #include <cmath>
17
+ #include <cstdlib>
14
18
  #include <cstring>
15
- #include <stdint.h>
16
19
 
17
20
  #include <algorithm>
18
21
 
19
- #include <faiss/utils/random.h>
20
- #include <faiss/utils/utils.h>
21
22
  #include <faiss/utils/distances.h>
22
23
  #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/random.h>
25
+ #include <faiss/utils/utils.h>
23
26
 
24
27
  #include <faiss/impl/FaissAssert.h>
25
28
 
@@ -29,16 +32,14 @@
29
32
 
30
33
  namespace faiss {
31
34
 
32
-
33
35
  /****************************************************
34
36
  * Optimization code
35
37
  ****************************************************/
36
38
 
37
- SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
38
- {
39
+ SimulatedAnnealingParameters::SimulatedAnnealingParameters() {
39
40
  // set some reasonable defaults for the optimization
40
41
  init_temperature = 0.7;
41
- temperature_decay = pow (0.9, 1/500.);
42
+ temperature_decay = pow(0.9, 1 / 500.);
42
43
  // reduce by a factor of 0.9 every 500 iterations
43
44
  n_iter = 500000;
44
45
  n_redo = 2;
@@ -50,44 +51,37 @@ SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
50
51
 
51
52
  // what would the cost update be if iw and jw were swapped?
52
53
  // default implementation just computes both and computes the difference
53
- double PermutationObjective::cost_update (
54
- const int *perm, int iw, int jw) const
55
- {
56
- double orig_cost = compute_cost (perm);
54
+ double PermutationObjective::cost_update(const int* perm, int iw, int jw)
55
+ const {
56
+ double orig_cost = compute_cost(perm);
57
57
 
58
- std::vector<int> perm2 (n);
58
+ std::vector<int> perm2(n);
59
59
  for (int i = 0; i < n; i++)
60
60
  perm2[i] = perm[i];
61
61
  perm2[iw] = perm[jw];
62
62
  perm2[jw] = perm[iw];
63
63
 
64
- double new_cost = compute_cost (perm2.data());
64
+ double new_cost = compute_cost(perm2.data());
65
65
  return new_cost - orig_cost;
66
66
  }
67
67
 
68
-
69
-
70
-
71
- SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
72
- PermutationObjective *obj,
73
- const SimulatedAnnealingParameters &p):
74
- SimulatedAnnealingParameters (p),
75
- obj (obj),
76
- n(obj->n),
77
- logfile (nullptr)
78
- {
79
- rnd = new RandomGenerator (p.seed);
80
- FAISS_THROW_IF_NOT (n < 100000 && n >=0 );
68
+ SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer(
69
+ PermutationObjective* obj,
70
+ const SimulatedAnnealingParameters& p)
71
+ : SimulatedAnnealingParameters(p),
72
+ obj(obj),
73
+ n(obj->n),
74
+ logfile(nullptr) {
75
+ rnd = new RandomGenerator(p.seed);
76
+ FAISS_THROW_IF_NOT(n < 100000 && n >= 0);
81
77
  }
82
78
 
83
- SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
84
- {
79
+ SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer() {
85
80
  delete rnd;
86
81
  }
87
82
 
88
83
  // run the optimization and return the best result in best_perm
89
- double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
90
- {
84
+ double SimulatedAnnealingOptimizer::run_optimization(int* best_perm) {
91
85
  double min_cost = 1e30;
92
86
 
93
87
  // just do a few runs of the annealing and keep the lowest output cost
@@ -95,84 +89,89 @@ double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
95
89
  std::vector<int> perm(n);
96
90
  for (int i = 0; i < n; i++)
97
91
  perm[i] = i;
98
- if (init_random) {
92
+ if (init_random) {
99
93
  for (int i = 0; i < n; i++) {
100
- int j = i + rnd->rand_int (n - i);
101
- std::swap (perm[i], perm[j]);
94
+ int j = i + rnd->rand_int(n - i);
95
+ std::swap(perm[i], perm[j]);
102
96
  }
103
97
  }
104
- float cost = optimize (perm.data());
105
- if (logfile) fprintf (logfile, "\n");
106
- if(verbose > 1) {
107
- printf (" optimization run %d: cost=%g %s\n",
108
- it, cost, cost < min_cost ? "keep" : "");
98
+ float cost = optimize(perm.data());
99
+ if (logfile)
100
+ fprintf(logfile, "\n");
101
+ if (verbose > 1) {
102
+ printf(" optimization run %d: cost=%g %s\n",
103
+ it,
104
+ cost,
105
+ cost < min_cost ? "keep" : "");
109
106
  }
110
107
  if (cost < min_cost) {
111
- memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
108
+ memcpy(best_perm, perm.data(), sizeof(perm[0]) * n);
112
109
  min_cost = cost;
113
110
  }
114
111
  }
115
- return min_cost;
112
+ return min_cost;
116
113
  }
117
114
 
118
115
  // perform the optimization loop, starting from and modifying
119
116
  // permutation in-place
120
- double SimulatedAnnealingOptimizer::optimize (int *perm)
121
- {
122
- double cost = init_cost = obj->compute_cost (perm);
117
+ double SimulatedAnnealingOptimizer::optimize(int* perm) {
118
+ double cost = init_cost = obj->compute_cost(perm);
123
119
  int log2n = 0;
124
- while (!(n <= (1 << log2n))) log2n++;
120
+ while (!(n <= (1 << log2n)))
121
+ log2n++;
125
122
  double temperature = init_temperature;
126
- int n_swap = 0, n_hot = 0;
123
+ int n_swap = 0, n_hot = 0;
127
124
  for (int it = 0; it < n_iter; it++) {
128
125
  temperature = temperature * temperature_decay;
129
126
  int iw, jw;
130
127
  if (only_bit_flips) {
131
- iw = rnd->rand_int (n);
132
- jw = iw ^ (1 << rnd->rand_int (log2n));
128
+ iw = rnd->rand_int(n);
129
+ jw = iw ^ (1 << rnd->rand_int(log2n));
133
130
  } else {
134
- iw = rnd->rand_int (n);
135
- jw = rnd->rand_int (n - 1);
136
- if (jw == iw) jw++;
131
+ iw = rnd->rand_int(n);
132
+ jw = rnd->rand_int(n - 1);
133
+ if (jw == iw)
134
+ jw++;
137
135
  }
138
- double delta_cost = obj->cost_update (perm, iw, jw);
139
- if (delta_cost < 0 || rnd->rand_float () < temperature) {
140
- std::swap (perm[iw], perm[jw]);
136
+ double delta_cost = obj->cost_update(perm, iw, jw);
137
+ if (delta_cost < 0 || rnd->rand_float() < temperature) {
138
+ std::swap(perm[iw], perm[jw]);
141
139
  cost += delta_cost;
142
140
  n_swap++;
143
- if (delta_cost >= 0) n_hot++;
141
+ if (delta_cost >= 0)
142
+ n_hot++;
144
143
  }
145
- if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
146
- printf (" iteration %d cost %g temp %g n_swap %d "
147
- "(%d hot) \r",
148
- it, cost, temperature, n_swap, n_hot);
144
+ if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
145
+ printf(" iteration %d cost %g temp %g n_swap %d "
146
+ "(%d hot) \r",
147
+ it,
148
+ cost,
149
+ temperature,
150
+ n_swap,
151
+ n_hot);
149
152
  fflush(stdout);
150
153
  }
151
154
  if (logfile) {
152
- fprintf (logfile, "%d %g %g %d %d\n",
153
- it, cost, temperature, n_swap, n_hot);
155
+ fprintf(logfile,
156
+ "%d %g %g %d %d\n",
157
+ it,
158
+ cost,
159
+ temperature,
160
+ n_swap,
161
+ n_hot);
154
162
  }
155
- }
156
- if (verbose > 1) printf("\n");
163
+ }
164
+ if (verbose > 1)
165
+ printf("\n");
157
166
  return cost;
158
167
  }
159
168
 
160
-
161
-
162
-
163
-
164
169
  /****************************************************
165
170
  * Cost functions: ReproduceDistanceTable
166
171
  ****************************************************/
167
172
 
168
-
169
-
170
-
171
-
172
-
173
- static inline int hamming_dis (uint64_t a, uint64_t b)
174
- {
175
- return __builtin_popcountl (a ^ b);
173
+ static inline int hamming_dis(uint64_t a, uint64_t b) {
174
+ return __builtin_popcountl(a ^ b);
176
175
  }
177
176
 
178
177
  namespace {
@@ -182,14 +181,14 @@ struct ReproduceWithHammingObjective : PermutationObjective {
182
181
  int nbits;
183
182
  double dis_weight_factor;
184
183
 
185
- static double sqr (double x) { return x * x; }
186
-
184
+ static double sqr(double x) {
185
+ return x * x;
186
+ }
187
187
 
188
188
  // weighting of distances: it is more important to reproduce small
189
189
  // distances well
190
- double dis_weight (double x) const
191
- {
192
- return exp (-dis_weight_factor * x);
190
+ double dis_weight(double x) const {
191
+ return exp(-dis_weight_factor * x);
193
192
  }
194
193
 
195
194
  std::vector<double> target_dis; // wanted distances (size n^2)
@@ -197,101 +196,105 @@ struct ReproduceWithHammingObjective : PermutationObjective {
197
196
 
198
197
  // cost = quadratic difference between actual distance and Hamming distance
199
198
  double compute_cost(const int* perm) const override {
200
- double cost = 0;
201
- for (int i = 0; i < n; i++) {
202
- for (int j = 0; j < n; j++) {
203
- double wanted = target_dis[i * n + j];
204
- double w = weights[i * n + j];
205
- double actual = hamming_dis(perm[i], perm[j]);
206
- cost += w * sqr(wanted - actual);
199
+ double cost = 0;
200
+ for (int i = 0; i < n; i++) {
201
+ for (int j = 0; j < n; j++) {
202
+ double wanted = target_dis[i * n + j];
203
+ double w = weights[i * n + j];
204
+ double actual = hamming_dis(perm[i], perm[j]);
205
+ cost += w * sqr(wanted - actual);
206
+ }
207
207
  }
208
- }
209
- return cost;
208
+ return cost;
210
209
  }
211
210
 
212
-
213
211
  // what would the cost update be if iw and jw were swapped?
214
212
  // computed in O(n) instead of O(n^2) for the full re-computation
215
213
  double cost_update(const int* perm, int iw, int jw) const override {
216
- double delta_cost = 0;
214
+ double delta_cost = 0;
217
215
 
218
- for (int i = 0; i < n; i++) {
219
- if (i == iw) {
220
- for (int j = 0; j < n; j++) {
221
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
222
- double actual = hamming_dis(perm[i], perm[j]);
223
- delta_cost -= w * sqr(wanted - actual);
224
- double new_actual =
225
- hamming_dis(perm[jw], perm[j == iw ? jw : j == jw ? iw : j]);
226
- delta_cost += w * sqr(wanted - new_actual);
227
- }
228
- } else if (i == jw) {
229
- for (int j = 0; j < n; j++) {
230
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
231
- double actual = hamming_dis(perm[i], perm[j]);
232
- delta_cost -= w * sqr(wanted - actual);
233
- double new_actual =
234
- hamming_dis(perm[iw], perm[j == iw ? jw : j == jw ? iw : j]);
235
- delta_cost += w * sqr(wanted - new_actual);
236
- }
237
- } else {
238
- int j = iw;
239
- {
240
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
241
- double actual = hamming_dis(perm[i], perm[j]);
242
- delta_cost -= w * sqr(wanted - actual);
243
- double new_actual = hamming_dis(perm[i], perm[jw]);
244
- delta_cost += w * sqr(wanted - new_actual);
245
- }
246
- j = jw;
247
- {
248
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
249
- double actual = hamming_dis(perm[i], perm[j]);
250
- delta_cost -= w * sqr(wanted - actual);
251
- double new_actual = hamming_dis(perm[i], perm[iw]);
252
- delta_cost += w * sqr(wanted - new_actual);
253
- }
216
+ for (int i = 0; i < n; i++) {
217
+ if (i == iw) {
218
+ for (int j = 0; j < n; j++) {
219
+ double wanted = target_dis[i * n + j],
220
+ w = weights[i * n + j];
221
+ double actual = hamming_dis(perm[i], perm[j]);
222
+ delta_cost -= w * sqr(wanted - actual);
223
+ double new_actual = hamming_dis(
224
+ perm[jw],
225
+ perm[j == iw ? jw
226
+ : j == jw ? iw
227
+ : j]);
228
+ delta_cost += w * sqr(wanted - new_actual);
229
+ }
230
+ } else if (i == jw) {
231
+ for (int j = 0; j < n; j++) {
232
+ double wanted = target_dis[i * n + j],
233
+ w = weights[i * n + j];
234
+ double actual = hamming_dis(perm[i], perm[j]);
235
+ delta_cost -= w * sqr(wanted - actual);
236
+ double new_actual = hamming_dis(
237
+ perm[iw],
238
+ perm[j == iw ? jw
239
+ : j == jw ? iw
240
+ : j]);
241
+ delta_cost += w * sqr(wanted - new_actual);
242
+ }
243
+ } else {
244
+ int j = iw;
245
+ {
246
+ double wanted = target_dis[i * n + j],
247
+ w = weights[i * n + j];
248
+ double actual = hamming_dis(perm[i], perm[j]);
249
+ delta_cost -= w * sqr(wanted - actual);
250
+ double new_actual = hamming_dis(perm[i], perm[jw]);
251
+ delta_cost += w * sqr(wanted - new_actual);
252
+ }
253
+ j = jw;
254
+ {
255
+ double wanted = target_dis[i * n + j],
256
+ w = weights[i * n + j];
257
+ double actual = hamming_dis(perm[i], perm[j]);
258
+ delta_cost -= w * sqr(wanted - actual);
259
+ double new_actual = hamming_dis(perm[i], perm[iw]);
260
+ delta_cost += w * sqr(wanted - new_actual);
261
+ }
262
+ }
254
263
  }
255
- }
256
264
 
257
- return delta_cost;
265
+ return delta_cost;
258
266
  }
259
267
 
260
-
261
-
262
- ReproduceWithHammingObjective (
263
- int nbits,
264
- const std::vector<double> & dis_table,
265
- double dis_weight_factor):
266
- nbits (nbits), dis_weight_factor (dis_weight_factor)
267
- {
268
+ ReproduceWithHammingObjective(
269
+ int nbits,
270
+ const std::vector<double>& dis_table,
271
+ double dis_weight_factor)
272
+ : nbits(nbits), dis_weight_factor(dis_weight_factor) {
268
273
  n = 1 << nbits;
269
- FAISS_THROW_IF_NOT (dis_table.size() == n * n);
270
- set_affine_target_dis (dis_table);
274
+ FAISS_THROW_IF_NOT(dis_table.size() == n * n);
275
+ set_affine_target_dis(dis_table);
271
276
  }
272
277
 
273
- void set_affine_target_dis (const std::vector<double> & dis_table)
274
- {
278
+ void set_affine_target_dis(const std::vector<double>& dis_table) {
275
279
  double sum = 0, sum2 = 0;
276
280
  int n2 = n * n;
277
281
  for (int i = 0; i < n2; i++) {
278
- sum += dis_table [i];
279
- sum2 += dis_table [i] * dis_table [i];
282
+ sum += dis_table[i];
283
+ sum2 += dis_table[i] * dis_table[i];
280
284
  }
281
285
  double mean = sum / n2;
282
286
  double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
283
287
 
284
- target_dis.resize (n2);
288
+ target_dis.resize(n2);
285
289
 
286
290
  for (int i = 0; i < n2; i++) {
287
291
  // the mapping function
288
- double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
289
- nbits / 2;
292
+ double td = (dis_table[i] - mean) / stddev * sqrt(nbits / 4) +
293
+ nbits / 2;
290
294
  target_dis[i] = td;
291
295
  // compute a weight
292
- weights.push_back (dis_weight (td));
296
+ weights.push_back(dis_weight(td));
293
297
  }
294
-
295
298
  }
296
299
 
297
300
  ~ReproduceWithHammingObjective() override {}
@@ -301,27 +304,23 @@ struct ReproduceWithHammingObjective : PermutationObjective {
301
304
 
302
305
  // weighting of distances: it is more important to reproduce small
303
306
  // distances well
304
- double ReproduceDistancesObjective::dis_weight (double x) const
305
- {
306
- return exp (-dis_weight_factor * x);
307
+ double ReproduceDistancesObjective::dis_weight(double x) const {
308
+ return exp(-dis_weight_factor * x);
307
309
  }
308
310
 
309
-
310
- double ReproduceDistancesObjective::get_source_dis (int i, int j) const
311
- {
312
- return source_dis [i * n + j];
311
+ double ReproduceDistancesObjective::get_source_dis(int i, int j) const {
312
+ return source_dis[i * n + j];
313
313
  }
314
314
 
315
315
  // cost = quadratic difference between actual distance and Hamming distance
316
- double ReproduceDistancesObjective::compute_cost (const int *perm) const
317
- {
316
+ double ReproduceDistancesObjective::compute_cost(const int* perm) const {
318
317
  double cost = 0;
319
318
  for (int i = 0; i < n; i++) {
320
319
  for (int j = 0; j < n; j++) {
321
- double wanted = target_dis [i * n + j];
322
- double w = weights [i * n + j];
323
- double actual = get_source_dis (perm[i], perm[j]);
324
- cost += w * sqr (wanted - actual);
320
+ double wanted = target_dis[i * n + j];
321
+ double w = weights[i * n + j];
322
+ double actual = get_source_dis(perm[i], perm[j]);
323
+ cost += w * sqr(wanted - actual);
325
324
  }
326
325
  }
327
326
  return cost;
@@ -329,79 +328,75 @@ double ReproduceDistancesObjective::compute_cost (const int *perm) const
329
328
 
330
329
  // what would the cost update be if iw and jw were swapped?
331
330
  // computed in O(n) instead of O(n^2) for the full re-computation
332
- double ReproduceDistancesObjective::cost_update(
333
- const int *perm, int iw, int jw) const
334
- {
331
+ double ReproduceDistancesObjective::cost_update(const int* perm, int iw, int jw)
332
+ const {
335
333
  double delta_cost = 0;
336
- for (int i = 0; i < n; i++) {
334
+ for (int i = 0; i < n; i++) {
337
335
  if (i == iw) {
338
336
  for (int j = 0; j < n; j++) {
339
- double wanted = target_dis [i * n + j],
340
- w = weights [i * n + j];
341
- double actual = get_source_dis (perm[i], perm[j]);
342
- delta_cost -= w * sqr (wanted - actual);
343
- double new_actual = get_source_dis (
344
- perm[jw],
345
- perm[j == iw ? jw : j == jw ? iw : j]);
346
- delta_cost += w * sqr (wanted - new_actual);
337
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
338
+ double actual = get_source_dis(perm[i], perm[j]);
339
+ delta_cost -= w * sqr(wanted - actual);
340
+ double new_actual = get_source_dis(
341
+ perm[jw],
342
+ perm[j == iw ? jw
343
+ : j == jw ? iw
344
+ : j]);
345
+ delta_cost += w * sqr(wanted - new_actual);
347
346
  }
348
347
  } else if (i == jw) {
349
348
  for (int j = 0; j < n; j++) {
350
- double wanted = target_dis [i * n + j],
351
- w = weights [i * n + j];
352
- double actual = get_source_dis (perm[i], perm[j]);
353
- delta_cost -= w * sqr (wanted - actual);
354
- double new_actual = get_source_dis (
355
- perm[iw],
356
- perm[j == iw ? jw : j == jw ? iw : j]);
357
- delta_cost += w * sqr (wanted - new_actual);
349
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
350
+ double actual = get_source_dis(perm[i], perm[j]);
351
+ delta_cost -= w * sqr(wanted - actual);
352
+ double new_actual = get_source_dis(
353
+ perm[iw],
354
+ perm[j == iw ? jw
355
+ : j == jw ? iw
356
+ : j]);
357
+ delta_cost += w * sqr(wanted - new_actual);
358
358
  }
359
- } else {
359
+ } else {
360
360
  int j = iw;
361
361
  {
362
- double wanted = target_dis [i * n + j],
363
- w = weights [i * n + j];
364
- double actual = get_source_dis (perm[i], perm[j]);
365
- delta_cost -= w * sqr (wanted - actual);
366
- double new_actual = get_source_dis (perm[i], perm[jw]);
367
- delta_cost += w * sqr (wanted - new_actual);
362
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
363
+ double actual = get_source_dis(perm[i], perm[j]);
364
+ delta_cost -= w * sqr(wanted - actual);
365
+ double new_actual = get_source_dis(perm[i], perm[jw]);
366
+ delta_cost += w * sqr(wanted - new_actual);
368
367
  }
369
368
  j = jw;
370
369
  {
371
- double wanted = target_dis [i * n + j],
372
- w = weights [i * n + j];
373
- double actual = get_source_dis (perm[i], perm[j]);
374
- delta_cost -= w * sqr (wanted - actual);
375
- double new_actual = get_source_dis (perm[i], perm[iw]);
376
- delta_cost += w * sqr (wanted - new_actual);
370
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
371
+ double actual = get_source_dis(perm[i], perm[j]);
372
+ delta_cost -= w * sqr(wanted - actual);
373
+ double new_actual = get_source_dis(perm[i], perm[iw]);
374
+ delta_cost += w * sqr(wanted - new_actual);
377
375
  }
378
376
  }
379
377
  }
380
- return delta_cost;
378
+ return delta_cost;
381
379
  }
382
380
 
383
-
384
-
385
- ReproduceDistancesObjective::ReproduceDistancesObjective (
386
- int n,
387
- const double *source_dis_in,
388
- const double *target_dis_in,
389
- double dis_weight_factor):
390
- dis_weight_factor (dis_weight_factor),
391
- target_dis (target_dis_in)
392
- {
381
+ ReproduceDistancesObjective::ReproduceDistancesObjective(
382
+ int n,
383
+ const double* source_dis_in,
384
+ const double* target_dis_in,
385
+ double dis_weight_factor)
386
+ : dis_weight_factor(dis_weight_factor), target_dis(target_dis_in) {
393
387
  this->n = n;
394
- set_affine_target_dis (source_dis_in);
388
+ set_affine_target_dis(source_dis_in);
395
389
  }
396
390
 
397
- void ReproduceDistancesObjective::compute_mean_stdev (
398
- const double *tab, size_t n2,
399
- double *mean_out, double *stddev_out)
400
- {
391
+ void ReproduceDistancesObjective::compute_mean_stdev(
392
+ const double* tab,
393
+ size_t n2,
394
+ double* mean_out,
395
+ double* stddev_out) {
401
396
  double sum = 0, sum2 = 0;
402
397
  for (int i = 0; i < n2; i++) {
403
- sum += tab [i];
404
- sum2 += tab [i] * tab [i];
398
+ sum += tab[i];
399
+ sum2 += tab[i] * tab[i];
405
400
  }
406
401
  double mean = sum / n2;
407
402
  double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
@@ -409,32 +404,34 @@ void ReproduceDistancesObjective::compute_mean_stdev (
409
404
  *stddev_out = stddev;
410
405
  }
411
406
 
412
- void ReproduceDistancesObjective::set_affine_target_dis (
413
- const double *source_dis_in)
414
- {
407
+ void ReproduceDistancesObjective::set_affine_target_dis(
408
+ const double* source_dis_in) {
415
409
  int n2 = n * n;
416
410
 
417
411
  double mean_src, stddev_src;
418
- compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);
412
+ compute_mean_stdev(source_dis_in, n2, &mean_src, &stddev_src);
419
413
 
420
414
  double mean_target, stddev_target;
421
- compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);
415
+ compute_mean_stdev(target_dis, n2, &mean_target, &stddev_target);
422
416
 
423
- printf ("map mean %g std %g -> mean %g std %g\n",
424
- mean_src, stddev_src, mean_target, stddev_target);
417
+ printf("map mean %g std %g -> mean %g std %g\n",
418
+ mean_src,
419
+ stddev_src,
420
+ mean_target,
421
+ stddev_target);
425
422
 
426
- source_dis.resize (n2);
427
- weights.resize (n2);
423
+ source_dis.resize(n2);
424
+ weights.resize(n2);
428
425
 
429
426
  for (int i = 0; i < n2; i++) {
430
427
  // the mapping function
431
- source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
432
- * stddev_target + mean_target;
428
+ source_dis[i] =
429
+ (source_dis_in[i] - mean_src) / stddev_src * stddev_target +
430
+ mean_target;
433
431
 
434
432
  // compute a weight
435
- weights [i] = dis_weight (target_dis[i]);
433
+ weights[i] = dis_weight(target_dis[i]);
436
434
  }
437
-
438
435
  }
439
436
 
440
437
  /****************************************************
@@ -444,8 +441,7 @@ void ReproduceDistancesObjective::set_affine_target_dis (
444
441
  /// Maintains a 3D table of elementary costs.
445
442
  /// Accumulates elements based on Hamming distance comparisons
446
443
  template <typename Ttab, typename Taccu>
447
- struct Score3Computer: PermutationObjective {
448
-
444
+ struct Score3Computer : PermutationObjective {
449
445
  int nc;
450
446
 
451
447
  // cost matrix of size nc * nc *nc
@@ -453,21 +449,18 @@ struct Score3Computer: PermutationObjective {
453
449
  // where x has PQ code i, y- PQ code j and y+ PQ code k
454
450
  std::vector<Ttab> n_gt;
455
451
 
456
-
457
452
  /// the cost is a triple loop on the nc * nc * nc matrix of entries.
458
453
  ///
459
- Taccu compute (const int * perm) const
460
- {
454
+ Taccu compute(const int* perm) const {
461
455
  Taccu accu = 0;
462
- const Ttab *p = n_gt.data();
456
+ const Ttab* p = n_gt.data();
463
457
  for (int i = 0; i < nc; i++) {
464
- int ip = perm [i];
458
+ int ip = perm[i];
465
459
  for (int j = 0; j < nc; j++) {
466
- int jp = perm [j];
460
+ int jp = perm[j];
467
461
  for (int k = 0; k < nc; k++) {
468
- int kp = perm [k];
469
- if (hamming_dis (ip, jp) <
470
- hamming_dis (ip, kp)) {
462
+ int kp = perm[k];
463
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
471
464
  accu += *p; // n_gt [ ( i * nc + j) * nc + k];
472
465
  }
473
466
  p++;
@@ -477,7 +470,6 @@ struct Score3Computer: PermutationObjective {
477
470
  return accu;
478
471
  }
479
472
 
480
-
481
473
  /** cost update if entries iw and jw of the permutation would be
482
474
  * swapped.
483
475
  *
@@ -487,25 +479,23 @@ struct Score3Computer: PermutationObjective {
487
479
  * cells. Practical speedup is about 8x, and the code is quite
488
480
  * complex :-/
489
481
  */
490
- Taccu compute_update (const int *perm, int iw, int jw) const
491
- {
492
- assert (iw != jw);
493
- if (iw > jw) std::swap (iw, jw);
482
+ Taccu compute_update(const int* perm, int iw, int jw) const {
483
+ assert(iw != jw);
484
+ if (iw > jw)
485
+ std::swap(iw, jw);
494
486
 
495
487
  Taccu accu = 0;
496
- const Ttab * n_gt_i = n_gt.data();
488
+ const Ttab* n_gt_i = n_gt.data();
497
489
  for (int i = 0; i < nc; i++) {
498
- int ip0 = perm [i];
499
- int ip = perm [i == iw ? jw : i == jw ? iw : i];
490
+ int ip0 = perm[i];
491
+ int ip = perm[i == iw ? jw : i == jw ? iw : i];
500
492
 
501
- //accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
493
+ // accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
502
494
 
503
- accu += update_i_cross (perm, iw, jw,
504
- ip0, ip, n_gt_i);
495
+ accu += update_i_cross(perm, iw, jw, ip0, ip, n_gt_i);
505
496
 
506
497
  if (ip != ip0)
507
- accu += update_i_plane (perm, iw, jw,
508
- ip0, ip, n_gt_i);
498
+ accu += update_i_plane(perm, iw, jw, ip0, ip, n_gt_i);
509
499
 
510
500
  n_gt_i += nc * nc;
511
501
  }
@@ -513,23 +503,26 @@ struct Score3Computer: PermutationObjective {
513
503
  return accu;
514
504
  }
515
505
 
516
-
517
- Taccu update_i (const int *perm, int iw, int jw,
518
- int ip0, int ip, const Ttab * n_gt_i) const
519
- {
506
+ Taccu update_i(
507
+ const int* perm,
508
+ int iw,
509
+ int jw,
510
+ int ip0,
511
+ int ip,
512
+ const Ttab* n_gt_i) const {
520
513
  Taccu accu = 0;
521
- const Ttab *n_gt_ij = n_gt_i;
514
+ const Ttab* n_gt_ij = n_gt_i;
522
515
  for (int j = 0; j < nc; j++) {
523
516
  int jp0 = perm[j];
524
- int jp = perm [j == iw ? jw : j == jw ? iw : j];
517
+ int jp = perm[j == iw ? jw : j == jw ? iw : j];
525
518
  for (int k = 0; k < nc; k++) {
526
- int kp0 = perm [k];
527
- int kp = perm [k == iw ? jw : k == jw ? iw : k];
528
- int ng = n_gt_ij [k];
529
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
519
+ int kp0 = perm[k];
520
+ int kp = perm[k == iw ? jw : k == jw ? iw : k];
521
+ int ng = n_gt_ij[k];
522
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
530
523
  accu += ng;
531
524
  }
532
- if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
525
+ if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
533
526
  accu -= ng;
534
527
  }
535
528
  }
@@ -539,23 +532,27 @@ struct Score3Computer: PermutationObjective {
539
532
  }
540
533
 
541
534
  // 2 inner loops for the case ip0 != ip
542
- Taccu update_i_plane (const int *perm, int iw, int jw,
543
- int ip0, int ip, const Ttab * n_gt_i) const
544
- {
535
+ Taccu update_i_plane(
536
+ const int* perm,
537
+ int iw,
538
+ int jw,
539
+ int ip0,
540
+ int ip,
541
+ const Ttab* n_gt_i) const {
545
542
  Taccu accu = 0;
546
- const Ttab *n_gt_ij = n_gt_i;
543
+ const Ttab* n_gt_ij = n_gt_i;
547
544
 
548
545
  for (int j = 0; j < nc; j++) {
549
546
  if (j != iw && j != jw) {
550
547
  int jp = perm[j];
551
548
  for (int k = 0; k < nc; k++) {
552
549
  if (k != iw && k != jw) {
553
- int kp = perm [k];
554
- Ttab ng = n_gt_ij [k];
555
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
550
+ int kp = perm[k];
551
+ Ttab ng = n_gt_ij[k];
552
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
556
553
  accu += ng;
557
554
  }
558
- if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
555
+ if (hamming_dis(ip0, jp) < hamming_dis(ip0, kp)) {
559
556
  accu -= ng;
560
557
  }
561
558
  }
@@ -567,114 +564,128 @@ struct Score3Computer: PermutationObjective {
567
564
  }
568
565
 
569
566
  /// used for the 8 cells where the 3 indices are swapped
570
- inline Taccu update_k (const int *perm, int iw, int jw,
571
- int ip0, int ip, int jp0, int jp,
572
- int k,
573
- const Ttab * n_gt_ij) const
574
- {
567
+ inline Taccu update_k(
568
+ const int* perm,
569
+ int iw,
570
+ int jw,
571
+ int ip0,
572
+ int ip,
573
+ int jp0,
574
+ int jp,
575
+ int k,
576
+ const Ttab* n_gt_ij) const {
575
577
  Taccu accu = 0;
576
- int kp0 = perm [k];
577
- int kp = perm [k == iw ? jw : k == jw ? iw : k];
578
- Ttab ng = n_gt_ij [k];
579
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
578
+ int kp0 = perm[k];
579
+ int kp = perm[k == iw ? jw : k == jw ? iw : k];
580
+ Ttab ng = n_gt_ij[k];
581
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
580
582
  accu += ng;
581
583
  }
582
- if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
584
+ if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
583
585
  accu -= ng;
584
586
  }
585
587
  return accu;
586
588
  }
587
589
 
588
590
  /// compute update on a line of k's, where i and j are swapped
589
- Taccu update_j_line (const int *perm, int iw, int jw,
590
- int ip0, int ip, int jp0, int jp,
591
- const Ttab * n_gt_ij) const
592
- {
591
+ Taccu update_j_line(
592
+ const int* perm,
593
+ int iw,
594
+ int jw,
595
+ int ip0,
596
+ int ip,
597
+ int jp0,
598
+ int jp,
599
+ const Ttab* n_gt_ij) const {
593
600
  Taccu accu = 0;
594
601
  for (int k = 0; k < nc; k++) {
595
- if (k == iw || k == jw) continue;
596
- int kp = perm [k];
597
- Ttab ng = n_gt_ij [k];
598
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
602
+ if (k == iw || k == jw)
603
+ continue;
604
+ int kp = perm[k];
605
+ Ttab ng = n_gt_ij[k];
606
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
599
607
  accu += ng;
600
608
  }
601
- if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
609
+ if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp)) {
602
610
  accu -= ng;
603
611
  }
604
612
  }
605
613
  return accu;
606
614
  }
607
615
 
608
-
609
616
  /// considers the 2 pairs of crossing lines j=iw or jw and k = iw or jw
610
- Taccu update_i_cross (const int *perm, int iw, int jw,
611
- int ip0, int ip, const Ttab * n_gt_i) const
612
- {
617
+ Taccu update_i_cross(
618
+ const int* perm,
619
+ int iw,
620
+ int jw,
621
+ int ip0,
622
+ int ip,
623
+ const Ttab* n_gt_i) const {
613
624
  Taccu accu = 0;
614
- const Ttab *n_gt_ij = n_gt_i;
625
+ const Ttab* n_gt_ij = n_gt_i;
615
626
 
616
627
  for (int j = 0; j < nc; j++) {
617
628
  int jp0 = perm[j];
618
- int jp = perm [j == iw ? jw : j == jw ? iw : j];
629
+ int jp = perm[j == iw ? jw : j == jw ? iw : j];
619
630
 
620
- accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
621
- accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
631
+ accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
632
+ accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
622
633
 
623
634
  if (jp != jp0)
624
- accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
635
+ accu += update_j_line(perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
625
636
 
626
637
  n_gt_ij += nc;
627
638
  }
628
639
  return accu;
629
640
  }
630
641
 
631
-
632
642
  /// PermutationObjective implementation (just negates the scores
633
643
  /// for minimization)
634
644
 
635
645
  double compute_cost(const int* perm) const override {
636
- return -compute(perm);
646
+ return -compute(perm);
637
647
  }
638
648
 
639
649
  double cost_update(const int* perm, int iw, int jw) const override {
640
- double ret = -compute_update(perm, iw, jw);
641
- return ret;
650
+ double ret = -compute_update(perm, iw, jw);
651
+ return ret;
642
652
  }
643
653
 
644
654
  ~Score3Computer() override {}
645
655
  };
646
656
 
647
-
648
-
649
-
650
-
651
657
  struct IndirectSort {
652
- const float *tab;
653
- bool operator () (int a, int b) {return tab[a] < tab[b]; }
658
+ const float* tab;
659
+ bool operator()(int a, int b) {
660
+ return tab[a] < tab[b];
661
+ }
654
662
  };
655
663
 
656
-
657
-
658
- struct RankingScore2: Score3Computer<float, double> {
664
+ struct RankingScore2 : Score3Computer<float, double> {
659
665
  int nbits;
660
666
  int nq, nb;
661
667
  const uint32_t *qcodes, *bcodes;
662
- const float *gt_distances;
663
-
664
- RankingScore2 (int nbits, int nq, int nb,
665
- const uint32_t *qcodes, const uint32_t *bcodes,
666
- const float *gt_distances):
667
- nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
668
- bcodes(bcodes), gt_distances(gt_distances)
669
- {
668
+ const float* gt_distances;
669
+
670
+ RankingScore2(
671
+ int nbits,
672
+ int nq,
673
+ int nb,
674
+ const uint32_t* qcodes,
675
+ const uint32_t* bcodes,
676
+ const float* gt_distances)
677
+ : nbits(nbits),
678
+ nq(nq),
679
+ nb(nb),
680
+ qcodes(qcodes),
681
+ bcodes(bcodes),
682
+ gt_distances(gt_distances) {
670
683
  n = nc = 1 << nbits;
671
- n_gt.resize (nc * nc * nc);
672
- init_n_gt ();
684
+ n_gt.resize(nc * nc * nc);
685
+ init_n_gt();
673
686
  }
674
687
 
675
-
676
- double rank_weight (int r)
677
- {
688
+ double rank_weight(int r) {
678
689
  return 1.0 / (r + 1);
679
690
  }
680
691
 
@@ -683,271 +694,290 @@ struct RankingScore2: Score3Computer<float, double> {
683
694
  /// they are the ranks of j and k respectively.
684
695
  /// specific version for diff-of-rank weighting, cannot be optimized
685
696
  /// with a cumulative table
686
- double accum_gt_weight_diff (const std::vector<int> & a,
687
- const std::vector<int> & b)
688
- {
697
+ double accum_gt_weight_diff(
698
+ const std::vector<int>& a,
699
+ const std::vector<int>& b) {
689
700
  int nb = b.size(), na = a.size();
690
701
 
691
702
  double accu = 0;
692
703
  int j = 0;
693
704
  for (int i = 0; i < na; i++) {
694
705
  int ai = a[i];
695
- while (j < nb && ai >= b[j]) j++;
706
+ while (j < nb && ai >= b[j])
707
+ j++;
696
708
 
697
709
  double accu_i = 0;
698
710
  for (int k = j; k < b.size(); k++)
699
- accu_i += rank_weight (b[k] - ai);
700
-
701
- accu += rank_weight (ai) * accu_i;
711
+ accu_i += rank_weight(b[k] - ai);
702
712
 
713
+ accu += rank_weight(ai) * accu_i;
703
714
  }
704
715
  return accu;
705
716
  }
706
717
 
707
- void init_n_gt ()
708
- {
718
+ void init_n_gt() {
709
719
  for (int q = 0; q < nq; q++) {
710
- const float *gtd = gt_distances + q * nb;
711
- const uint32_t *cb = bcodes;// all same codes
712
- float * n_gt_q = & n_gt [qcodes[q] * nc * nc];
720
+ const float* gtd = gt_distances + q * nb;
721
+ const uint32_t* cb = bcodes; // all same codes
722
+ float* n_gt_q = &n_gt[qcodes[q] * nc * nc];
713
723
 
714
- printf("init gt for q=%d/%d \r", q, nq); fflush(stdout);
724
+ printf("init gt for q=%d/%d \r", q, nq);
725
+ fflush(stdout);
715
726
 
716
- std::vector<int> rankv (nb);
717
- int * ranks = rankv.data();
727
+ std::vector<int> rankv(nb);
728
+ int* ranks = rankv.data();
718
729
 
719
730
  // elements in each code bin, ordered by rank within each bin
720
- std::vector<std::vector<int> > tab (nc);
731
+ std::vector<std::vector<int>> tab(nc);
721
732
 
722
733
  { // build rank table
723
734
  IndirectSort s = {gtd};
724
- for (int j = 0; j < nb; j++) ranks[j] = j;
725
- std::sort (ranks, ranks + nb, s);
735
+ for (int j = 0; j < nb; j++)
736
+ ranks[j] = j;
737
+ std::sort(ranks, ranks + nb, s);
726
738
  }
727
739
 
728
740
  for (int rank = 0; rank < nb; rank++) {
729
- int i = ranks [rank];
730
- tab [cb[i]].push_back (rank);
741
+ int i = ranks[rank];
742
+ tab[cb[i]].push_back(rank);
731
743
  }
732
744
 
733
-
734
745
  // this is very expensive. Any suggestion for improvement
735
746
  // welcome.
736
747
  for (int i = 0; i < nc; i++) {
737
- std::vector<int> & di = tab[i];
748
+ std::vector<int>& di = tab[i];
738
749
  for (int j = 0; j < nc; j++) {
739
- std::vector<int> & dj = tab[j];
740
- n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);
741
-
750
+ std::vector<int>& dj = tab[j];
751
+ n_gt_q[i * nc + j] += accum_gt_weight_diff(di, dj);
742
752
  }
743
753
  }
744
-
745
754
  }
746
-
747
755
  }
748
-
749
756
  };
750
757
 
751
-
752
758
  /*****************************************
753
759
  * PolysemousTraining
754
760
  ******************************************/
755
761
 
756
-
757
-
758
- PolysemousTraining::PolysemousTraining ()
759
- {
762
+ PolysemousTraining::PolysemousTraining() {
760
763
  optimization_type = OT_ReproduceDistances_affine;
761
764
  ntrain_permutation = 0;
762
765
  dis_weight_factor = log(2);
766
+ // max 20 G RAM
767
+ max_memory = (size_t)(20) * 1024 * 1024 * 1024;
763
768
  }
764
769
 
765
-
766
-
767
- void PolysemousTraining::optimize_reproduce_distances (
768
- ProductQuantizer &pq) const
769
- {
770
-
770
+ void PolysemousTraining::optimize_reproduce_distances(
771
+ ProductQuantizer& pq) const {
771
772
  int dsub = pq.dsub;
772
773
 
773
774
  int n = pq.ksub;
774
775
  int nbits = pq.nbits;
775
776
 
776
- #pragma omp parallel for
777
+ size_t mem1 = memory_usage_per_thread(pq);
778
+ int nt = std::min(omp_get_max_threads(), int(pq.M));
779
+ FAISS_THROW_IF_NOT_FMT(
780
+ mem1 < max_memory,
781
+ "Polysemous training will use %zd bytes per thread, while the max is set to %zd",
782
+ mem1,
783
+ max_memory);
784
+
785
+ if (mem1 * nt > max_memory) {
786
+ nt = max_memory / mem1;
787
+ fprintf(stderr,
788
+ "Polysemous training: WARN, reducing number of threads to %d to save memory",
789
+ nt);
790
+ }
791
+
792
+ #pragma omp parallel for num_threads(nt)
777
793
  for (int m = 0; m < pq.M; m++) {
778
794
  std::vector<double> dis_table;
779
795
 
780
796
  // printf ("Optimizing quantizer %d\n", m);
781
797
 
782
- float * centroids = pq.get_centroids (m, 0);
798
+ float* centroids = pq.get_centroids(m, 0);
783
799
 
784
800
  for (int i = 0; i < n; i++) {
785
801
  for (int j = 0; j < n; j++) {
786
- dis_table.push_back (fvec_L2sqr (centroids + i * dsub,
787
- centroids + j * dsub,
788
- dsub));
802
+ dis_table.push_back(fvec_L2sqr(
803
+ centroids + i * dsub, centroids + j * dsub, dsub));
789
804
  }
790
805
  }
791
806
 
792
- std::vector<int> perm (n);
793
- ReproduceWithHammingObjective obj (
794
- nbits, dis_table,
795
- dis_weight_factor);
796
-
807
+ std::vector<int> perm(n);
808
+ ReproduceWithHammingObjective obj(nbits, dis_table, dis_weight_factor);
797
809
 
798
- SimulatedAnnealingOptimizer optim (&obj, *this);
810
+ SimulatedAnnealingOptimizer optim(&obj, *this);
799
811
 
800
812
  if (log_pattern.size()) {
801
813
  char fname[256];
802
- snprintf (fname, 256, log_pattern.c_str(), m);
803
- printf ("opening log file %s\n", fname);
804
- optim.logfile = fopen (fname, "w");
805
- FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
814
+ snprintf(fname, 256, log_pattern.c_str(), m);
815
+ printf("opening log file %s\n", fname);
816
+ optim.logfile = fopen(fname, "w");
817
+ FAISS_THROW_IF_NOT_MSG(optim.logfile, "could not open logfile");
806
818
  }
807
- double final_cost = optim.run_optimization (perm.data());
819
+ double final_cost = optim.run_optimization(perm.data());
808
820
 
809
821
  if (verbose > 0) {
810
- printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
811
- m, optim.init_cost, final_cost);
822
+ printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
823
+ m,
824
+ optim.init_cost,
825
+ final_cost);
812
826
  }
813
827
 
814
- if (log_pattern.size()) fclose (optim.logfile);
828
+ if (log_pattern.size())
829
+ fclose(optim.logfile);
815
830
 
816
831
  std::vector<float> centroids_copy;
817
832
  for (int i = 0; i < dsub * n; i++)
818
- centroids_copy.push_back (centroids[i]);
833
+ centroids_copy.push_back(centroids[i]);
819
834
 
820
835
  for (int i = 0; i < n; i++)
821
- memcpy (centroids + perm[i] * dsub,
822
- centroids_copy.data() + i * dsub,
823
- dsub * sizeof(centroids[0]));
824
-
836
+ memcpy(centroids + perm[i] * dsub,
837
+ centroids_copy.data() + i * dsub,
838
+ dsub * sizeof(centroids[0]));
825
839
  }
826
-
827
840
  }
828
841
 
829
-
830
- void PolysemousTraining::optimize_ranking (
831
- ProductQuantizer &pq, size_t n, const float *x) const
832
- {
833
-
842
+ void PolysemousTraining::optimize_ranking(
843
+ ProductQuantizer& pq,
844
+ size_t n,
845
+ const float* x) const {
834
846
  int dsub = pq.dsub;
835
-
836
847
  int nbits = pq.nbits;
837
848
 
838
- std::vector<uint8_t> all_codes (pq.code_size * n);
849
+ std::vector<uint8_t> all_codes(pq.code_size * n);
839
850
 
840
- pq.compute_codes (x, all_codes.data(), n);
851
+ pq.compute_codes(x, all_codes.data(), n);
841
852
 
842
- FAISS_THROW_IF_NOT (pq.nbits == 8);
853
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
843
854
 
844
- if (n == 0)
845
- pq.compute_sdc_table ();
855
+ if (n == 0) {
856
+ pq.compute_sdc_table();
857
+ }
846
858
 
847
859
  #pragma omp parallel for
848
860
  for (int m = 0; m < pq.M; m++) {
849
861
  size_t nq, nb;
850
- std::vector <uint32_t> codes; // query codes, then db codes
851
- std::vector <float> gt_distances; // nq * nb matrix of distances
862
+ std::vector<uint32_t> codes; // query codes, then db codes
863
+ std::vector<float> gt_distances; // nq * nb matrix of distances
852
864
 
853
865
  if (n > 0) {
854
- std::vector<float> xtrain (n * dsub);
866
+ std::vector<float> xtrain(n * dsub);
855
867
  for (int i = 0; i < n; i++)
856
- memcpy (xtrain.data() + i * dsub,
857
- x + i * pq.d + m * dsub,
858
- sizeof(float) * dsub);
868
+ memcpy(xtrain.data() + i * dsub,
869
+ x + i * pq.d + m * dsub,
870
+ sizeof(float) * dsub);
859
871
 
860
- codes.resize (n);
872
+ codes.resize(n);
861
873
  for (int i = 0; i < n; i++)
862
- codes [i] = all_codes [i * pq.code_size + m];
874
+ codes[i] = all_codes[i * pq.code_size + m];
863
875
 
864
- nq = n / 4; nb = n - nq;
865
- const float *xq = xtrain.data();
866
- const float *xb = xq + nq * dsub;
876
+ nq = n / 4;
877
+ nb = n - nq;
878
+ const float* xq = xtrain.data();
879
+ const float* xb = xq + nq * dsub;
867
880
 
868
- gt_distances.resize (nq * nb);
881
+ gt_distances.resize(nq * nb);
869
882
 
870
- pairwise_L2sqr (dsub,
871
- nq, xq,
872
- nb, xb,
873
- gt_distances.data());
883
+ pairwise_L2sqr(dsub, nq, xq, nb, xb, gt_distances.data());
874
884
  } else {
875
885
  nq = nb = pq.ksub;
876
- codes.resize (2 * nq);
886
+ codes.resize(2 * nq);
877
887
  for (int i = 0; i < nq; i++)
878
- codes[i] = codes [i + nq] = i;
888
+ codes[i] = codes[i + nq] = i;
879
889
 
880
- gt_distances.resize (nq * nb);
890
+ gt_distances.resize(nq * nb);
881
891
 
882
- memcpy (gt_distances.data (),
883
- pq.sdc_table.data () + m * nq * nb,
884
- sizeof (float) * nq * nb);
892
+ memcpy(gt_distances.data(),
893
+ pq.sdc_table.data() + m * nq * nb,
894
+ sizeof(float) * nq * nb);
885
895
  }
886
896
 
887
- double t0 = getmillisecs ();
897
+ double t0 = getmillisecs();
888
898
 
889
- PermutationObjective *obj = new RankingScore2 (
890
- nbits, nq, nb,
891
- codes.data(), codes.data() + nq,
892
- gt_distances.data ());
893
- ScopeDeleter1<PermutationObjective> del (obj);
899
+ PermutationObjective* obj = new RankingScore2(
900
+ nbits,
901
+ nq,
902
+ nb,
903
+ codes.data(),
904
+ codes.data() + nq,
905
+ gt_distances.data());
906
+ ScopeDeleter1<PermutationObjective> del(obj);
894
907
 
895
908
  if (verbose > 0) {
896
909
  printf(" m=%d, nq=%zd, nb=%zd, intialize RankingScore "
897
910
  "in %.3f ms\n",
898
- m, nq, nb, getmillisecs () - t0);
911
+ m,
912
+ nq,
913
+ nb,
914
+ getmillisecs() - t0);
899
915
  }
900
916
 
901
- SimulatedAnnealingOptimizer optim (obj, *this);
917
+ SimulatedAnnealingOptimizer optim(obj, *this);
902
918
 
903
919
  if (log_pattern.size()) {
904
920
  char fname[256];
905
- snprintf (fname, 256, log_pattern.c_str(), m);
906
- printf ("opening log file %s\n", fname);
907
- optim.logfile = fopen (fname, "w");
908
- FAISS_THROW_IF_NOT_FMT (optim.logfile,
909
- "could not open logfile %s", fname);
921
+ snprintf(fname, 256, log_pattern.c_str(), m);
922
+ printf("opening log file %s\n", fname);
923
+ optim.logfile = fopen(fname, "w");
924
+ FAISS_THROW_IF_NOT_FMT(
925
+ optim.logfile, "could not open logfile %s", fname);
910
926
  }
911
927
 
912
- std::vector<int> perm (pq.ksub);
928
+ std::vector<int> perm(pq.ksub);
913
929
 
914
- double final_cost = optim.run_optimization (perm.data());
915
- printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
916
- m, optim.init_cost, final_cost);
930
+ double final_cost = optim.run_optimization(perm.data());
931
+ printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
932
+ m,
933
+ optim.init_cost,
934
+ final_cost);
917
935
 
918
- if (log_pattern.size()) fclose (optim.logfile);
936
+ if (log_pattern.size())
937
+ fclose(optim.logfile);
919
938
 
920
- float * centroids = pq.get_centroids (m, 0);
939
+ float* centroids = pq.get_centroids(m, 0);
921
940
 
922
941
  std::vector<float> centroids_copy;
923
942
  for (int i = 0; i < dsub * pq.ksub; i++)
924
- centroids_copy.push_back (centroids[i]);
943
+ centroids_copy.push_back(centroids[i]);
925
944
 
926
945
  for (int i = 0; i < pq.ksub; i++)
927
- memcpy (centroids + perm[i] * dsub,
928
- centroids_copy.data() + i * dsub,
929
- dsub * sizeof(centroids[0]));
930
-
946
+ memcpy(centroids + perm[i] * dsub,
947
+ centroids_copy.data() + i * dsub,
948
+ dsub * sizeof(centroids[0]));
931
949
  }
932
-
933
950
  }
934
951
 
935
-
936
-
937
- void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq,
938
- size_t n, const float *x) const
939
- {
952
+ void PolysemousTraining::optimize_pq_for_hamming(
953
+ ProductQuantizer& pq,
954
+ size_t n,
955
+ const float* x) const {
940
956
  if (optimization_type == OT_None) {
941
-
942
957
  } else if (optimization_type == OT_ReproduceDistances_affine) {
943
- optimize_reproduce_distances (pq);
958
+ optimize_reproduce_distances(pq);
944
959
  } else {
945
- optimize_ranking (pq, n, x);
960
+ optimize_ranking(pq, n, x);
946
961
  }
947
962
 
948
- pq.compute_sdc_table ();
949
-
963
+ pq.compute_sdc_table();
950
964
  }
951
965
 
966
+ size_t PolysemousTraining::memory_usage_per_thread(
967
+ const ProductQuantizer& pq) const {
968
+ size_t n = pq.ksub;
969
+
970
+ switch (optimization_type) {
971
+ case OT_None:
972
+ return 0;
973
+ case OT_ReproduceDistances_affine:
974
+ return n * n * sizeof(double) * 3;
975
+ case OT_Ranking_weighted_diff:
976
+ return n * n * n * sizeof(float);
977
+ }
978
+
979
+ FAISS_THROW_MSG("Invalid optmization type");
980
+ return 0;
981
+ }
952
982
 
953
983
  } // namespace faiss