faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,199 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <memory>
13
+ #include <mutex>
14
+ #include <vector>
15
+
16
+ #include <omp.h>
17
+
18
+ #include <faiss/Index.h>
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/utils/Heap.h>
22
+ #include <faiss/utils/random.h>
23
+
24
+ namespace faiss {
25
+
26
+ /** Implementation of the Navigating Spreading-out Graph (NSG)
27
+ * data structure.
28
+ *
29
+ * Fast Approximate Nearest Neighbor Search With The
30
+ * Navigating Spreading-out Graph
31
+ *
32
+ * Cong Fu, Chao Xiang, Changxu Wang, Deng Cai, VLDB 2019
33
+ *
34
+ * This implementation is heavily influenced by the NSG
35
+ * implementation by ZJULearning Group
36
+ * (https://github.com/zjulearning/nsg)
37
+ *
38
+ * The NSG object stores only the neighbor link structure, see
39
+ * IndexNSG.h for the full index object.
40
+ */
41
+
42
+ struct DistanceComputer; // from AuxIndexStructures
43
+ struct Neighbor;
44
+ struct Node;
45
+
46
+ namespace nsg {
47
+
48
+ /***********************************************************
49
+ * Graph structure to store a graph.
50
+ *
51
+ * It is represented by an adjacency matrix `data`, where
52
+ * data[i, j] is the j-th neighbor of node i.
53
+ ***********************************************************/
54
+
55
+ template <class node_t>
56
+ struct Graph {
57
+ node_t* data; ///< the flattened adjacency matrix
58
+ int K; ///< nb of neighbors per node
59
+ int N; ///< total nb of nodes
60
+ bool own_fields; ///< the underlying data owned by itself or not
61
+
62
+ // construct from a known graph
63
+ Graph(node_t* data, int N, int K)
64
+ : data(data), K(K), N(N), own_fields(false) {}
65
+
66
+ // construct an empty graph
67
+ // NOTE: the newly allocated data needs to be destroyed at destruction time
68
+ Graph(int N, int K) : K(K), N(N), own_fields(true) {
69
+ data = new node_t[N * K];
70
+ }
71
+
72
+ // copy constructor
73
+ Graph(const Graph& g) : Graph(g.N, g.K) {
74
+ memcpy(data, g.data, N * K * sizeof(node_t));
75
+ }
76
+
77
+ // release the allocated memory if needed
78
+ ~Graph() {
79
+ if (own_fields) {
80
+ delete[] data;
81
+ }
82
+ }
83
+
84
+ // access the j-th neighbor of node i
85
+ inline node_t at(int i, int j) const {
86
+ return data[i * K + j];
87
+ }
88
+
89
+ // access the j-th neighbor of node i by reference
90
+ inline node_t& at(int i, int j) {
91
+ return data[i * K + j];
92
+ }
93
+ };
94
+
95
+ DistanceComputer* storage_distance_computer(const Index* storage);
96
+
97
+ } // namespace nsg
98
+
99
+ struct NSG {
100
+ /// internal storage of vectors (32 bits: this is expensive)
101
+ using storage_idx_t = int;
102
+
103
+ /// Faiss results are 64-bit
104
+ using idx_t = Index::idx_t;
105
+
106
+ int ntotal; ///< nb of nodes
107
+
108
+ /// construction-time parameters
109
+ int R; ///< nb of neighbors per node
110
+ int L; ///< length of the search path at construction time
111
+ int C; ///< candidate pool size at construction time
112
+
113
+ // search-time parameters
114
+ int search_L; ///< length of the search path
115
+
116
+ int enterpoint; ///< enterpoint
117
+
118
+ std::shared_ptr<nsg::Graph<int>> final_graph; ///< NSG graph structure
119
+
120
+ bool is_built; ///< NSG is built or not
121
+
122
+ RandomGenerator rng; ///< random generator
123
+
124
+ explicit NSG(int R = 32);
125
+
126
+ // build NSG from a KNN graph
127
+ void build(
128
+ Index* storage,
129
+ idx_t n,
130
+ const nsg::Graph<idx_t>& knn_graph,
131
+ bool verbose);
132
+
133
+ // reset the graph
134
+ void reset();
135
+
136
+ // search interface
137
+ void search(
138
+ DistanceComputer& dis,
139
+ int k,
140
+ idx_t* I,
141
+ float* D,
142
+ VisitedTable& vt) const;
143
+
144
+ // Compute the center point
145
+ void init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph);
146
+
147
+ // Search on a built graph.
148
+ // If collect_fullset is true, the visited nodes will be
149
+ // collected in `fullset`.
150
+ template <bool collect_fullset, class index_t>
151
+ void search_on_graph(
152
+ const nsg::Graph<index_t>& graph,
153
+ DistanceComputer& dis,
154
+ VisitedTable& vt,
155
+ int ep,
156
+ int pool_size,
157
+ std::vector<Neighbor>& retset,
158
+ std::vector<Node>& fullset) const;
159
+
160
+ // Add reverse links
161
+ void add_reverse_links(
162
+ int q,
163
+ std::vector<std::mutex>& locks,
164
+ DistanceComputer& dis,
165
+ nsg::Graph<Node>& graph);
166
+
167
+ void sync_prune(
168
+ int q,
169
+ std::vector<Node>& pool,
170
+ DistanceComputer& dis,
171
+ VisitedTable& vt,
172
+ const nsg::Graph<idx_t>& knn_graph,
173
+ nsg::Graph<Node>& graph);
174
+
175
+ void link(
176
+ Index* storage,
177
+ const nsg::Graph<idx_t>& knn_graph,
178
+ nsg::Graph<Node>& graph,
179
+ bool verbose);
180
+
181
+ // make NSG be fully connected
182
+ int tree_grow(Index* storage, std::vector<int>& degrees);
183
+
184
+ // count the size of the connected component
185
+ // using depth-first search starting from root
186
+ int dfs(VisitedTable& vt, int root, int cnt) const;
187
+
188
+ // attach one unlinked node
189
+ int attach_unlinked(
190
+ Index* storage,
191
+ VisitedTable& vt,
192
+ VisitedTable& vt2,
193
+ std::vector<int>& degrees);
194
+
195
+ // check the integrity of the NSG built
196
+ void check_graph() const;
197
+ };
198
+
199
+ } // namespace faiss
@@ -8,18 +8,21 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/impl/PolysemousTraining.h>
11
+ #include "faiss/impl/FaissAssert.h"
12
+
13
+ #include <omp.h>
14
+ #include <stdint.h>
11
15
 
12
- #include <cstdlib>
13
16
  #include <cmath>
17
+ #include <cstdlib>
14
18
  #include <cstring>
15
- #include <stdint.h>
16
19
 
17
20
  #include <algorithm>
18
21
 
19
- #include <faiss/utils/random.h>
20
- #include <faiss/utils/utils.h>
21
22
  #include <faiss/utils/distances.h>
22
23
  #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/random.h>
25
+ #include <faiss/utils/utils.h>
23
26
 
24
27
  #include <faiss/impl/FaissAssert.h>
25
28
 
@@ -29,16 +32,14 @@
29
32
 
30
33
  namespace faiss {
31
34
 
32
-
33
35
  /****************************************************
34
36
  * Optimization code
35
37
  ****************************************************/
36
38
 
37
- SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
38
- {
39
+ SimulatedAnnealingParameters::SimulatedAnnealingParameters() {
39
40
  // set some reasonable defaults for the optimization
40
41
  init_temperature = 0.7;
41
- temperature_decay = pow (0.9, 1/500.);
42
+ temperature_decay = pow(0.9, 1 / 500.);
42
43
  // reduce by a factor 0.9 every 500 it
43
44
  n_iter = 500000;
44
45
  n_redo = 2;
@@ -50,44 +51,37 @@ SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
50
51
 
51
52
  // what would the cost update be if iw and jw were swapped?
52
53
  // default implementation just computes both and computes the difference
53
- double PermutationObjective::cost_update (
54
- const int *perm, int iw, int jw) const
55
- {
56
- double orig_cost = compute_cost (perm);
54
+ double PermutationObjective::cost_update(const int* perm, int iw, int jw)
55
+ const {
56
+ double orig_cost = compute_cost(perm);
57
57
 
58
- std::vector<int> perm2 (n);
58
+ std::vector<int> perm2(n);
59
59
  for (int i = 0; i < n; i++)
60
60
  perm2[i] = perm[i];
61
61
  perm2[iw] = perm[jw];
62
62
  perm2[jw] = perm[iw];
63
63
 
64
- double new_cost = compute_cost (perm2.data());
64
+ double new_cost = compute_cost(perm2.data());
65
65
  return new_cost - orig_cost;
66
66
  }
67
67
 
68
-
69
-
70
-
71
- SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
72
- PermutationObjective *obj,
73
- const SimulatedAnnealingParameters &p):
74
- SimulatedAnnealingParameters (p),
75
- obj (obj),
76
- n(obj->n),
77
- logfile (nullptr)
78
- {
79
- rnd = new RandomGenerator (p.seed);
80
- FAISS_THROW_IF_NOT (n < 100000 && n >=0 );
68
+ SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer(
69
+ PermutationObjective* obj,
70
+ const SimulatedAnnealingParameters& p)
71
+ : SimulatedAnnealingParameters(p),
72
+ obj(obj),
73
+ n(obj->n),
74
+ logfile(nullptr) {
75
+ rnd = new RandomGenerator(p.seed);
76
+ FAISS_THROW_IF_NOT(n < 100000 && n >= 0);
81
77
  }
82
78
 
83
- SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
84
- {
79
+ SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer() {
85
80
  delete rnd;
86
81
  }
87
82
 
88
83
  // run the optimization and return the best result in best_perm
89
- double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
90
- {
84
+ double SimulatedAnnealingOptimizer::run_optimization(int* best_perm) {
91
85
  double min_cost = 1e30;
92
86
 
93
87
  // just do a few runs of the annealing and keep the lowest output cost
@@ -95,84 +89,89 @@ double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
95
89
  std::vector<int> perm(n);
96
90
  for (int i = 0; i < n; i++)
97
91
  perm[i] = i;
98
- if (init_random) {
92
+ if (init_random) {
99
93
  for (int i = 0; i < n; i++) {
100
- int j = i + rnd->rand_int (n - i);
101
- std::swap (perm[i], perm[j]);
94
+ int j = i + rnd->rand_int(n - i);
95
+ std::swap(perm[i], perm[j]);
102
96
  }
103
97
  }
104
- float cost = optimize (perm.data());
105
- if (logfile) fprintf (logfile, "\n");
106
- if(verbose > 1) {
107
- printf (" optimization run %d: cost=%g %s\n",
108
- it, cost, cost < min_cost ? "keep" : "");
98
+ float cost = optimize(perm.data());
99
+ if (logfile)
100
+ fprintf(logfile, "\n");
101
+ if (verbose > 1) {
102
+ printf(" optimization run %d: cost=%g %s\n",
103
+ it,
104
+ cost,
105
+ cost < min_cost ? "keep" : "");
109
106
  }
110
107
  if (cost < min_cost) {
111
- memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
108
+ memcpy(best_perm, perm.data(), sizeof(perm[0]) * n);
112
109
  min_cost = cost;
113
110
  }
114
111
  }
115
- return min_cost;
112
+ return min_cost;
116
113
  }
117
114
 
118
115
  // perform the optimization loop, starting from and modifying
119
116
  // permutation in-place
120
- double SimulatedAnnealingOptimizer::optimize (int *perm)
121
- {
122
- double cost = init_cost = obj->compute_cost (perm);
117
+ double SimulatedAnnealingOptimizer::optimize(int* perm) {
118
+ double cost = init_cost = obj->compute_cost(perm);
123
119
  int log2n = 0;
124
- while (!(n <= (1 << log2n))) log2n++;
120
+ while (!(n <= (1 << log2n)))
121
+ log2n++;
125
122
  double temperature = init_temperature;
126
- int n_swap = 0, n_hot = 0;
123
+ int n_swap = 0, n_hot = 0;
127
124
  for (int it = 0; it < n_iter; it++) {
128
125
  temperature = temperature * temperature_decay;
129
126
  int iw, jw;
130
127
  if (only_bit_flips) {
131
- iw = rnd->rand_int (n);
132
- jw = iw ^ (1 << rnd->rand_int (log2n));
128
+ iw = rnd->rand_int(n);
129
+ jw = iw ^ (1 << rnd->rand_int(log2n));
133
130
  } else {
134
- iw = rnd->rand_int (n);
135
- jw = rnd->rand_int (n - 1);
136
- if (jw == iw) jw++;
131
+ iw = rnd->rand_int(n);
132
+ jw = rnd->rand_int(n - 1);
133
+ if (jw == iw)
134
+ jw++;
137
135
  }
138
- double delta_cost = obj->cost_update (perm, iw, jw);
139
- if (delta_cost < 0 || rnd->rand_float () < temperature) {
140
- std::swap (perm[iw], perm[jw]);
136
+ double delta_cost = obj->cost_update(perm, iw, jw);
137
+ if (delta_cost < 0 || rnd->rand_float() < temperature) {
138
+ std::swap(perm[iw], perm[jw]);
141
139
  cost += delta_cost;
142
140
  n_swap++;
143
- if (delta_cost >= 0) n_hot++;
141
+ if (delta_cost >= 0)
142
+ n_hot++;
144
143
  }
145
- if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
146
- printf (" iteration %d cost %g temp %g n_swap %d "
147
- "(%d hot) \r",
148
- it, cost, temperature, n_swap, n_hot);
144
+ if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
145
+ printf(" iteration %d cost %g temp %g n_swap %d "
146
+ "(%d hot) \r",
147
+ it,
148
+ cost,
149
+ temperature,
150
+ n_swap,
151
+ n_hot);
149
152
  fflush(stdout);
150
153
  }
151
154
  if (logfile) {
152
- fprintf (logfile, "%d %g %g %d %d\n",
153
- it, cost, temperature, n_swap, n_hot);
155
+ fprintf(logfile,
156
+ "%d %g %g %d %d\n",
157
+ it,
158
+ cost,
159
+ temperature,
160
+ n_swap,
161
+ n_hot);
154
162
  }
155
- }
156
- if (verbose > 1) printf("\n");
163
+ }
164
+ if (verbose > 1)
165
+ printf("\n");
157
166
  return cost;
158
167
  }
159
168
 
160
-
161
-
162
-
163
-
164
169
  /****************************************************
165
170
  * Cost functions: ReproduceDistanceTable
166
171
  ****************************************************/
167
172
 
168
-
169
-
170
-
171
-
172
-
173
- static inline int hamming_dis (uint64_t a, uint64_t b)
174
- {
175
- return __builtin_popcountl (a ^ b);
173
+ static inline int hamming_dis(uint64_t a, uint64_t b) {
174
+ return __builtin_popcountl(a ^ b);
176
175
  }
177
176
 
178
177
  namespace {
@@ -182,14 +181,14 @@ struct ReproduceWithHammingObjective : PermutationObjective {
182
181
  int nbits;
183
182
  double dis_weight_factor;
184
183
 
185
- static double sqr (double x) { return x * x; }
186
-
184
+ static double sqr(double x) {
185
+ return x * x;
186
+ }
187
187
 
188
188
  // weighting of distances: it is more important to reproduce small
189
189
  // distances well
190
- double dis_weight (double x) const
191
- {
192
- return exp (-dis_weight_factor * x);
190
+ double dis_weight(double x) const {
191
+ return exp(-dis_weight_factor * x);
193
192
  }
194
193
 
195
194
  std::vector<double> target_dis; // wanted distances (size n^2)
@@ -197,101 +196,105 @@ struct ReproduceWithHammingObjective : PermutationObjective {
197
196
 
198
197
  // cost = quadratic difference between actual distance and Hamming distance
199
198
  double compute_cost(const int* perm) const override {
200
- double cost = 0;
201
- for (int i = 0; i < n; i++) {
202
- for (int j = 0; j < n; j++) {
203
- double wanted = target_dis[i * n + j];
204
- double w = weights[i * n + j];
205
- double actual = hamming_dis(perm[i], perm[j]);
206
- cost += w * sqr(wanted - actual);
199
+ double cost = 0;
200
+ for (int i = 0; i < n; i++) {
201
+ for (int j = 0; j < n; j++) {
202
+ double wanted = target_dis[i * n + j];
203
+ double w = weights[i * n + j];
204
+ double actual = hamming_dis(perm[i], perm[j]);
205
+ cost += w * sqr(wanted - actual);
206
+ }
207
207
  }
208
- }
209
- return cost;
208
+ return cost;
210
209
  }
211
210
 
212
-
213
211
  // what would the cost update be if iw and jw were swapped?
214
212
  // computed in O(n) instead of O(n^2) for the full re-computation
215
213
  double cost_update(const int* perm, int iw, int jw) const override {
216
- double delta_cost = 0;
214
+ double delta_cost = 0;
217
215
 
218
- for (int i = 0; i < n; i++) {
219
- if (i == iw) {
220
- for (int j = 0; j < n; j++) {
221
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
222
- double actual = hamming_dis(perm[i], perm[j]);
223
- delta_cost -= w * sqr(wanted - actual);
224
- double new_actual =
225
- hamming_dis(perm[jw], perm[j == iw ? jw : j == jw ? iw : j]);
226
- delta_cost += w * sqr(wanted - new_actual);
227
- }
228
- } else if (i == jw) {
229
- for (int j = 0; j < n; j++) {
230
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
231
- double actual = hamming_dis(perm[i], perm[j]);
232
- delta_cost -= w * sqr(wanted - actual);
233
- double new_actual =
234
- hamming_dis(perm[iw], perm[j == iw ? jw : j == jw ? iw : j]);
235
- delta_cost += w * sqr(wanted - new_actual);
236
- }
237
- } else {
238
- int j = iw;
239
- {
240
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
241
- double actual = hamming_dis(perm[i], perm[j]);
242
- delta_cost -= w * sqr(wanted - actual);
243
- double new_actual = hamming_dis(perm[i], perm[jw]);
244
- delta_cost += w * sqr(wanted - new_actual);
245
- }
246
- j = jw;
247
- {
248
- double wanted = target_dis[i * n + j], w = weights[i * n + j];
249
- double actual = hamming_dis(perm[i], perm[j]);
250
- delta_cost -= w * sqr(wanted - actual);
251
- double new_actual = hamming_dis(perm[i], perm[iw]);
252
- delta_cost += w * sqr(wanted - new_actual);
253
- }
216
+ for (int i = 0; i < n; i++) {
217
+ if (i == iw) {
218
+ for (int j = 0; j < n; j++) {
219
+ double wanted = target_dis[i * n + j],
220
+ w = weights[i * n + j];
221
+ double actual = hamming_dis(perm[i], perm[j]);
222
+ delta_cost -= w * sqr(wanted - actual);
223
+ double new_actual = hamming_dis(
224
+ perm[jw],
225
+ perm[j == iw ? jw
226
+ : j == jw ? iw
227
+ : j]);
228
+ delta_cost += w * sqr(wanted - new_actual);
229
+ }
230
+ } else if (i == jw) {
231
+ for (int j = 0; j < n; j++) {
232
+ double wanted = target_dis[i * n + j],
233
+ w = weights[i * n + j];
234
+ double actual = hamming_dis(perm[i], perm[j]);
235
+ delta_cost -= w * sqr(wanted - actual);
236
+ double new_actual = hamming_dis(
237
+ perm[iw],
238
+ perm[j == iw ? jw
239
+ : j == jw ? iw
240
+ : j]);
241
+ delta_cost += w * sqr(wanted - new_actual);
242
+ }
243
+ } else {
244
+ int j = iw;
245
+ {
246
+ double wanted = target_dis[i * n + j],
247
+ w = weights[i * n + j];
248
+ double actual = hamming_dis(perm[i], perm[j]);
249
+ delta_cost -= w * sqr(wanted - actual);
250
+ double new_actual = hamming_dis(perm[i], perm[jw]);
251
+ delta_cost += w * sqr(wanted - new_actual);
252
+ }
253
+ j = jw;
254
+ {
255
+ double wanted = target_dis[i * n + j],
256
+ w = weights[i * n + j];
257
+ double actual = hamming_dis(perm[i], perm[j]);
258
+ delta_cost -= w * sqr(wanted - actual);
259
+ double new_actual = hamming_dis(perm[i], perm[iw]);
260
+ delta_cost += w * sqr(wanted - new_actual);
261
+ }
262
+ }
254
263
  }
255
- }
256
264
 
257
- return delta_cost;
265
+ return delta_cost;
258
266
  }
259
267
 
260
-
261
-
262
- ReproduceWithHammingObjective (
263
- int nbits,
264
- const std::vector<double> & dis_table,
265
- double dis_weight_factor):
266
- nbits (nbits), dis_weight_factor (dis_weight_factor)
267
- {
268
+ ReproduceWithHammingObjective(
269
+ int nbits,
270
+ const std::vector<double>& dis_table,
271
+ double dis_weight_factor)
272
+ : nbits(nbits), dis_weight_factor(dis_weight_factor) {
268
273
  n = 1 << nbits;
269
- FAISS_THROW_IF_NOT (dis_table.size() == n * n);
270
- set_affine_target_dis (dis_table);
274
+ FAISS_THROW_IF_NOT(dis_table.size() == n * n);
275
+ set_affine_target_dis(dis_table);
271
276
  }
272
277
 
273
- void set_affine_target_dis (const std::vector<double> & dis_table)
274
- {
278
+ void set_affine_target_dis(const std::vector<double>& dis_table) {
275
279
  double sum = 0, sum2 = 0;
276
280
  int n2 = n * n;
277
281
  for (int i = 0; i < n2; i++) {
278
- sum += dis_table [i];
279
- sum2 += dis_table [i] * dis_table [i];
282
+ sum += dis_table[i];
283
+ sum2 += dis_table[i] * dis_table[i];
280
284
  }
281
285
  double mean = sum / n2;
282
286
  double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
283
287
 
284
- target_dis.resize (n2);
288
+ target_dis.resize(n2);
285
289
 
286
290
  for (int i = 0; i < n2; i++) {
287
291
  // the mapping function
288
- double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
289
- nbits / 2;
292
+ double td = (dis_table[i] - mean) / stddev * sqrt(nbits / 4) +
293
+ nbits / 2;
290
294
  target_dis[i] = td;
291
295
  // compute a weight
292
- weights.push_back (dis_weight (td));
296
+ weights.push_back(dis_weight(td));
293
297
  }
294
-
295
298
  }
296
299
 
297
300
  ~ReproduceWithHammingObjective() override {}
@@ -301,27 +304,23 @@ struct ReproduceWithHammingObjective : PermutationObjective {
301
304
 
302
305
  // weihgting of distances: it is more important to reproduce small
303
306
  // distances well
304
- double ReproduceDistancesObjective::dis_weight (double x) const
305
- {
306
- return exp (-dis_weight_factor * x);
307
+ double ReproduceDistancesObjective::dis_weight(double x) const {
308
+ return exp(-dis_weight_factor * x);
307
309
  }
308
310
 
309
-
310
- double ReproduceDistancesObjective::get_source_dis (int i, int j) const
311
- {
312
- return source_dis [i * n + j];
311
+ double ReproduceDistancesObjective::get_source_dis(int i, int j) const {
312
+ return source_dis[i * n + j];
313
313
  }
314
314
 
315
315
  // cost = quadratic difference between actual distance and Hamming distance
316
- double ReproduceDistancesObjective::compute_cost (const int *perm) const
317
- {
316
+ double ReproduceDistancesObjective::compute_cost(const int* perm) const {
318
317
  double cost = 0;
319
318
  for (int i = 0; i < n; i++) {
320
319
  for (int j = 0; j < n; j++) {
321
- double wanted = target_dis [i * n + j];
322
- double w = weights [i * n + j];
323
- double actual = get_source_dis (perm[i], perm[j]);
324
- cost += w * sqr (wanted - actual);
320
+ double wanted = target_dis[i * n + j];
321
+ double w = weights[i * n + j];
322
+ double actual = get_source_dis(perm[i], perm[j]);
323
+ cost += w * sqr(wanted - actual);
325
324
  }
326
325
  }
327
326
  return cost;
@@ -329,79 +328,75 @@ double ReproduceDistancesObjective::compute_cost (const int *perm) const
329
328
 
330
329
  // what would the cost update be if iw and jw were swapped?
331
330
  // computed in O(n) instead of O(n^2) for the full re-computation
332
- double ReproduceDistancesObjective::cost_update(
333
- const int *perm, int iw, int jw) const
334
- {
331
+ double ReproduceDistancesObjective::cost_update(const int* perm, int iw, int jw)
332
+ const {
335
333
  double delta_cost = 0;
336
- for (int i = 0; i < n; i++) {
334
+ for (int i = 0; i < n; i++) {
337
335
  if (i == iw) {
338
336
  for (int j = 0; j < n; j++) {
339
- double wanted = target_dis [i * n + j],
340
- w = weights [i * n + j];
341
- double actual = get_source_dis (perm[i], perm[j]);
342
- delta_cost -= w * sqr (wanted - actual);
343
- double new_actual = get_source_dis (
344
- perm[jw],
345
- perm[j == iw ? jw : j == jw ? iw : j]);
346
- delta_cost += w * sqr (wanted - new_actual);
337
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
338
+ double actual = get_source_dis(perm[i], perm[j]);
339
+ delta_cost -= w * sqr(wanted - actual);
340
+ double new_actual = get_source_dis(
341
+ perm[jw],
342
+ perm[j == iw ? jw
343
+ : j == jw ? iw
344
+ : j]);
345
+ delta_cost += w * sqr(wanted - new_actual);
347
346
  }
348
347
  } else if (i == jw) {
349
348
  for (int j = 0; j < n; j++) {
350
- double wanted = target_dis [i * n + j],
351
- w = weights [i * n + j];
352
- double actual = get_source_dis (perm[i], perm[j]);
353
- delta_cost -= w * sqr (wanted - actual);
354
- double new_actual = get_source_dis (
355
- perm[iw],
356
- perm[j == iw ? jw : j == jw ? iw : j]);
357
- delta_cost += w * sqr (wanted - new_actual);
349
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
350
+ double actual = get_source_dis(perm[i], perm[j]);
351
+ delta_cost -= w * sqr(wanted - actual);
352
+ double new_actual = get_source_dis(
353
+ perm[iw],
354
+ perm[j == iw ? jw
355
+ : j == jw ? iw
356
+ : j]);
357
+ delta_cost += w * sqr(wanted - new_actual);
358
358
  }
359
- } else {
359
+ } else {
360
360
  int j = iw;
361
361
  {
362
- double wanted = target_dis [i * n + j],
363
- w = weights [i * n + j];
364
- double actual = get_source_dis (perm[i], perm[j]);
365
- delta_cost -= w * sqr (wanted - actual);
366
- double new_actual = get_source_dis (perm[i], perm[jw]);
367
- delta_cost += w * sqr (wanted - new_actual);
362
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
363
+ double actual = get_source_dis(perm[i], perm[j]);
364
+ delta_cost -= w * sqr(wanted - actual);
365
+ double new_actual = get_source_dis(perm[i], perm[jw]);
366
+ delta_cost += w * sqr(wanted - new_actual);
368
367
  }
369
368
  j = jw;
370
369
  {
371
- double wanted = target_dis [i * n + j],
372
- w = weights [i * n + j];
373
- double actual = get_source_dis (perm[i], perm[j]);
374
- delta_cost -= w * sqr (wanted - actual);
375
- double new_actual = get_source_dis (perm[i], perm[iw]);
376
- delta_cost += w * sqr (wanted - new_actual);
370
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
371
+ double actual = get_source_dis(perm[i], perm[j]);
372
+ delta_cost -= w * sqr(wanted - actual);
373
+ double new_actual = get_source_dis(perm[i], perm[iw]);
374
+ delta_cost += w * sqr(wanted - new_actual);
377
375
  }
378
376
  }
379
377
  }
380
- return delta_cost;
378
+ return delta_cost;
381
379
  }
382
380
 
383
-
384
-
385
- ReproduceDistancesObjective::ReproduceDistancesObjective (
386
- int n,
387
- const double *source_dis_in,
388
- const double *target_dis_in,
389
- double dis_weight_factor):
390
- dis_weight_factor (dis_weight_factor),
391
- target_dis (target_dis_in)
392
- {
381
+ ReproduceDistancesObjective::ReproduceDistancesObjective(
382
+ int n,
383
+ const double* source_dis_in,
384
+ const double* target_dis_in,
385
+ double dis_weight_factor)
386
+ : dis_weight_factor(dis_weight_factor), target_dis(target_dis_in) {
393
387
  this->n = n;
394
- set_affine_target_dis (source_dis_in);
388
+ set_affine_target_dis(source_dis_in);
395
389
  }
396
390
 
397
- void ReproduceDistancesObjective::compute_mean_stdev (
398
- const double *tab, size_t n2,
399
- double *mean_out, double *stddev_out)
400
- {
391
+ void ReproduceDistancesObjective::compute_mean_stdev(
392
+ const double* tab,
393
+ size_t n2,
394
+ double* mean_out,
395
+ double* stddev_out) {
401
396
  double sum = 0, sum2 = 0;
402
397
  for (int i = 0; i < n2; i++) {
403
- sum += tab [i];
404
- sum2 += tab [i] * tab [i];
398
+ sum += tab[i];
399
+ sum2 += tab[i] * tab[i];
405
400
  }
406
401
  double mean = sum / n2;
407
402
  double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
@@ -409,32 +404,34 @@ void ReproduceDistancesObjective::compute_mean_stdev (
409
404
  *stddev_out = stddev;
410
405
  }
411
406
 
412
- void ReproduceDistancesObjective::set_affine_target_dis (
413
- const double *source_dis_in)
414
- {
407
+ void ReproduceDistancesObjective::set_affine_target_dis(
408
+ const double* source_dis_in) {
415
409
  int n2 = n * n;
416
410
 
417
411
  double mean_src, stddev_src;
418
- compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);
412
+ compute_mean_stdev(source_dis_in, n2, &mean_src, &stddev_src);
419
413
 
420
414
  double mean_target, stddev_target;
421
- compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);
415
+ compute_mean_stdev(target_dis, n2, &mean_target, &stddev_target);
422
416
 
423
- printf ("map mean %g std %g -> mean %g std %g\n",
424
- mean_src, stddev_src, mean_target, stddev_target);
417
+ printf("map mean %g std %g -> mean %g std %g\n",
418
+ mean_src,
419
+ stddev_src,
420
+ mean_target,
421
+ stddev_target);
425
422
 
426
- source_dis.resize (n2);
427
- weights.resize (n2);
423
+ source_dis.resize(n2);
424
+ weights.resize(n2);
428
425
 
429
426
  for (int i = 0; i < n2; i++) {
430
427
  // the mapping function
431
- source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
432
- * stddev_target + mean_target;
428
+ source_dis[i] =
429
+ (source_dis_in[i] - mean_src) / stddev_src * stddev_target +
430
+ mean_target;
433
431
 
434
432
  // compute a weight
435
- weights [i] = dis_weight (target_dis[i]);
433
+ weights[i] = dis_weight(target_dis[i]);
436
434
  }
437
-
438
435
  }
439
436
 
440
437
  /****************************************************
@@ -444,8 +441,7 @@ void ReproduceDistancesObjective::set_affine_target_dis (
444
441
  /// Maintains a 3D table of elementary costs.
445
442
  /// Accumulates elements based on Hamming distance comparisons
446
443
  template <typename Ttab, typename Taccu>
447
- struct Score3Computer: PermutationObjective {
448
-
444
+ struct Score3Computer : PermutationObjective {
449
445
  int nc;
450
446
 
451
447
  // cost matrix of size nc * nc *nc
@@ -453,21 +449,18 @@ struct Score3Computer: PermutationObjective {
453
449
  // where x has PQ code i, y- PQ code j and y+ PQ code k
454
450
  std::vector<Ttab> n_gt;
455
451
 
456
-
457
452
  /// the cost is a triple loop on the nc * nc * nc matrix of entries.
458
453
  ///
459
- Taccu compute (const int * perm) const
460
- {
454
+ Taccu compute(const int* perm) const {
461
455
  Taccu accu = 0;
462
- const Ttab *p = n_gt.data();
456
+ const Ttab* p = n_gt.data();
463
457
  for (int i = 0; i < nc; i++) {
464
- int ip = perm [i];
458
+ int ip = perm[i];
465
459
  for (int j = 0; j < nc; j++) {
466
- int jp = perm [j];
460
+ int jp = perm[j];
467
461
  for (int k = 0; k < nc; k++) {
468
- int kp = perm [k];
469
- if (hamming_dis (ip, jp) <
470
- hamming_dis (ip, kp)) {
462
+ int kp = perm[k];
463
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
471
464
  accu += *p; // n_gt [ ( i * nc + j) * nc + k];
472
465
  }
473
466
  p++;
@@ -477,7 +470,6 @@ struct Score3Computer: PermutationObjective {
477
470
  return accu;
478
471
  }
479
472
 
480
-
481
473
  /** cost update if entries iw and jw of the permutation would be
482
474
  * swapped.
483
475
  *
@@ -487,25 +479,23 @@ struct Score3Computer: PermutationObjective {
487
479
  * cells. Practical speedup is about 8x, and the code is quite
488
480
  * complex :-/
489
481
  */
490
- Taccu compute_update (const int *perm, int iw, int jw) const
491
- {
492
- assert (iw != jw);
493
- if (iw > jw) std::swap (iw, jw);
482
+ Taccu compute_update(const int* perm, int iw, int jw) const {
483
+ assert(iw != jw);
484
+ if (iw > jw)
485
+ std::swap(iw, jw);
494
486
 
495
487
  Taccu accu = 0;
496
- const Ttab * n_gt_i = n_gt.data();
488
+ const Ttab* n_gt_i = n_gt.data();
497
489
  for (int i = 0; i < nc; i++) {
498
- int ip0 = perm [i];
499
- int ip = perm [i == iw ? jw : i == jw ? iw : i];
490
+ int ip0 = perm[i];
491
+ int ip = perm[i == iw ? jw : i == jw ? iw : i];
500
492
 
501
- //accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
493
+ // accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
502
494
 
503
- accu += update_i_cross (perm, iw, jw,
504
- ip0, ip, n_gt_i);
495
+ accu += update_i_cross(perm, iw, jw, ip0, ip, n_gt_i);
505
496
 
506
497
  if (ip != ip0)
507
- accu += update_i_plane (perm, iw, jw,
508
- ip0, ip, n_gt_i);
498
+ accu += update_i_plane(perm, iw, jw, ip0, ip, n_gt_i);
509
499
 
510
500
  n_gt_i += nc * nc;
511
501
  }
@@ -513,23 +503,26 @@ struct Score3Computer: PermutationObjective {
513
503
  return accu;
514
504
  }
515
505
 
516
-
517
- Taccu update_i (const int *perm, int iw, int jw,
518
- int ip0, int ip, const Ttab * n_gt_i) const
519
- {
506
+ Taccu update_i(
507
+ const int* perm,
508
+ int iw,
509
+ int jw,
510
+ int ip0,
511
+ int ip,
512
+ const Ttab* n_gt_i) const {
520
513
  Taccu accu = 0;
521
- const Ttab *n_gt_ij = n_gt_i;
514
+ const Ttab* n_gt_ij = n_gt_i;
522
515
  for (int j = 0; j < nc; j++) {
523
516
  int jp0 = perm[j];
524
- int jp = perm [j == iw ? jw : j == jw ? iw : j];
517
+ int jp = perm[j == iw ? jw : j == jw ? iw : j];
525
518
  for (int k = 0; k < nc; k++) {
526
- int kp0 = perm [k];
527
- int kp = perm [k == iw ? jw : k == jw ? iw : k];
528
- int ng = n_gt_ij [k];
529
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
519
+ int kp0 = perm[k];
520
+ int kp = perm[k == iw ? jw : k == jw ? iw : k];
521
+ int ng = n_gt_ij[k];
522
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
530
523
  accu += ng;
531
524
  }
532
- if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
525
+ if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
533
526
  accu -= ng;
534
527
  }
535
528
  }
@@ -539,23 +532,27 @@ struct Score3Computer: PermutationObjective {
539
532
  }
540
533
 
541
534
  // 2 inner loops for the case ip0 != ip
542
- Taccu update_i_plane (const int *perm, int iw, int jw,
543
- int ip0, int ip, const Ttab * n_gt_i) const
544
- {
535
+ Taccu update_i_plane(
536
+ const int* perm,
537
+ int iw,
538
+ int jw,
539
+ int ip0,
540
+ int ip,
541
+ const Ttab* n_gt_i) const {
545
542
  Taccu accu = 0;
546
- const Ttab *n_gt_ij = n_gt_i;
543
+ const Ttab* n_gt_ij = n_gt_i;
547
544
 
548
545
  for (int j = 0; j < nc; j++) {
549
546
  if (j != iw && j != jw) {
550
547
  int jp = perm[j];
551
548
  for (int k = 0; k < nc; k++) {
552
549
  if (k != iw && k != jw) {
553
- int kp = perm [k];
554
- Ttab ng = n_gt_ij [k];
555
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
550
+ int kp = perm[k];
551
+ Ttab ng = n_gt_ij[k];
552
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
556
553
  accu += ng;
557
554
  }
558
- if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
555
+ if (hamming_dis(ip0, jp) < hamming_dis(ip0, kp)) {
559
556
  accu -= ng;
560
557
  }
561
558
  }
@@ -567,114 +564,128 @@ struct Score3Computer: PermutationObjective {
567
564
  }
568
565
 
569
566
  /// used for the 8 cells were the 3 indices are swapped
570
- inline Taccu update_k (const int *perm, int iw, int jw,
571
- int ip0, int ip, int jp0, int jp,
572
- int k,
573
- const Ttab * n_gt_ij) const
574
- {
567
+ inline Taccu update_k(
568
+ const int* perm,
569
+ int iw,
570
+ int jw,
571
+ int ip0,
572
+ int ip,
573
+ int jp0,
574
+ int jp,
575
+ int k,
576
+ const Ttab* n_gt_ij) const {
575
577
  Taccu accu = 0;
576
- int kp0 = perm [k];
577
- int kp = perm [k == iw ? jw : k == jw ? iw : k];
578
- Ttab ng = n_gt_ij [k];
579
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
578
+ int kp0 = perm[k];
579
+ int kp = perm[k == iw ? jw : k == jw ? iw : k];
580
+ Ttab ng = n_gt_ij[k];
581
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
580
582
  accu += ng;
581
583
  }
582
- if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
584
+ if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp0)) {
583
585
  accu -= ng;
584
586
  }
585
587
  return accu;
586
588
  }
587
589
 
588
590
  /// compute update on a line of k's, where i and j are swapped
589
- Taccu update_j_line (const int *perm, int iw, int jw,
590
- int ip0, int ip, int jp0, int jp,
591
- const Ttab * n_gt_ij) const
592
- {
591
+ Taccu update_j_line(
592
+ const int* perm,
593
+ int iw,
594
+ int jw,
595
+ int ip0,
596
+ int ip,
597
+ int jp0,
598
+ int jp,
599
+ const Ttab* n_gt_ij) const {
593
600
  Taccu accu = 0;
594
601
  for (int k = 0; k < nc; k++) {
595
- if (k == iw || k == jw) continue;
596
- int kp = perm [k];
597
- Ttab ng = n_gt_ij [k];
598
- if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
602
+ if (k == iw || k == jw)
603
+ continue;
604
+ int kp = perm[k];
605
+ Ttab ng = n_gt_ij[k];
606
+ if (hamming_dis(ip, jp) < hamming_dis(ip, kp)) {
599
607
  accu += ng;
600
608
  }
601
- if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
609
+ if (hamming_dis(ip0, jp0) < hamming_dis(ip0, kp)) {
602
610
  accu -= ng;
603
611
  }
604
612
  }
605
613
  return accu;
606
614
  }
607
615
 
608
-
609
616
  /// considers the 2 pairs of crossing lines j=iw or jw and k = iw or kw
610
- Taccu update_i_cross (const int *perm, int iw, int jw,
611
- int ip0, int ip, const Ttab * n_gt_i) const
612
- {
617
+ Taccu update_i_cross(
618
+ const int* perm,
619
+ int iw,
620
+ int jw,
621
+ int ip0,
622
+ int ip,
623
+ const Ttab* n_gt_i) const {
613
624
  Taccu accu = 0;
614
- const Ttab *n_gt_ij = n_gt_i;
625
+ const Ttab* n_gt_ij = n_gt_i;
615
626
 
616
627
  for (int j = 0; j < nc; j++) {
617
628
  int jp0 = perm[j];
618
- int jp = perm [j == iw ? jw : j == jw ? iw : j];
629
+ int jp = perm[j == iw ? jw : j == jw ? iw : j];
619
630
 
620
- accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
621
- accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
631
+ accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
632
+ accu += update_k(perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
622
633
 
623
634
  if (jp != jp0)
624
- accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
635
+ accu += update_j_line(perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
625
636
 
626
637
  n_gt_ij += nc;
627
638
  }
628
639
  return accu;
629
640
  }
630
641
 
631
-
632
642
  /// PermutationObjective implementeation (just negates the scores
633
643
  /// for minimization)
634
644
 
635
645
  double compute_cost(const int* perm) const override {
636
- return -compute(perm);
646
+ return -compute(perm);
637
647
  }
638
648
 
639
649
  double cost_update(const int* perm, int iw, int jw) const override {
640
- double ret = -compute_update(perm, iw, jw);
641
- return ret;
650
+ double ret = -compute_update(perm, iw, jw);
651
+ return ret;
642
652
  }
643
653
 
644
654
  ~Score3Computer() override {}
645
655
  };
646
656
 
647
-
648
-
649
-
650
-
651
657
  struct IndirectSort {
652
- const float *tab;
653
- bool operator () (int a, int b) {return tab[a] < tab[b]; }
658
+ const float* tab;
659
+ bool operator()(int a, int b) {
660
+ return tab[a] < tab[b];
661
+ }
654
662
  };
655
663
 
656
-
657
-
658
- struct RankingScore2: Score3Computer<float, double> {
664
+ struct RankingScore2 : Score3Computer<float, double> {
659
665
  int nbits;
660
666
  int nq, nb;
661
667
  const uint32_t *qcodes, *bcodes;
662
- const float *gt_distances;
663
-
664
- RankingScore2 (int nbits, int nq, int nb,
665
- const uint32_t *qcodes, const uint32_t *bcodes,
666
- const float *gt_distances):
667
- nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
668
- bcodes(bcodes), gt_distances(gt_distances)
669
- {
668
+ const float* gt_distances;
669
+
670
+ RankingScore2(
671
+ int nbits,
672
+ int nq,
673
+ int nb,
674
+ const uint32_t* qcodes,
675
+ const uint32_t* bcodes,
676
+ const float* gt_distances)
677
+ : nbits(nbits),
678
+ nq(nq),
679
+ nb(nb),
680
+ qcodes(qcodes),
681
+ bcodes(bcodes),
682
+ gt_distances(gt_distances) {
670
683
  n = nc = 1 << nbits;
671
- n_gt.resize (nc * nc * nc);
672
- init_n_gt ();
684
+ n_gt.resize(nc * nc * nc);
685
+ init_n_gt();
673
686
  }
674
687
 
675
-
676
- double rank_weight (int r)
677
- {
688
+ double rank_weight(int r) {
678
689
  return 1.0 / (r + 1);
679
690
  }
680
691
 
@@ -683,271 +694,290 @@ struct RankingScore2: Score3Computer<float, double> {
683
694
  /// they are the ranks of j and k respectively.
684
695
  /// specific version for diff-of-rank weighting, cannot optimized
685
696
  /// with a cumulative table
686
- double accum_gt_weight_diff (const std::vector<int> & a,
687
- const std::vector<int> & b)
688
- {
697
+ double accum_gt_weight_diff(
698
+ const std::vector<int>& a,
699
+ const std::vector<int>& b) {
689
700
  int nb = b.size(), na = a.size();
690
701
 
691
702
  double accu = 0;
692
703
  int j = 0;
693
704
  for (int i = 0; i < na; i++) {
694
705
  int ai = a[i];
695
- while (j < nb && ai >= b[j]) j++;
706
+ while (j < nb && ai >= b[j])
707
+ j++;
696
708
 
697
709
  double accu_i = 0;
698
710
  for (int k = j; k < b.size(); k++)
699
- accu_i += rank_weight (b[k] - ai);
700
-
701
- accu += rank_weight (ai) * accu_i;
711
+ accu_i += rank_weight(b[k] - ai);
702
712
 
713
+ accu += rank_weight(ai) * accu_i;
703
714
  }
704
715
  return accu;
705
716
  }
706
717
 
707
- void init_n_gt ()
708
- {
718
+ void init_n_gt() {
709
719
  for (int q = 0; q < nq; q++) {
710
- const float *gtd = gt_distances + q * nb;
711
- const uint32_t *cb = bcodes;// all same codes
712
- float * n_gt_q = & n_gt [qcodes[q] * nc * nc];
720
+ const float* gtd = gt_distances + q * nb;
721
+ const uint32_t* cb = bcodes; // all same codes
722
+ float* n_gt_q = &n_gt[qcodes[q] * nc * nc];
713
723
 
714
- printf("init gt for q=%d/%d \r", q, nq); fflush(stdout);
724
+ printf("init gt for q=%d/%d \r", q, nq);
725
+ fflush(stdout);
715
726
 
716
- std::vector<int> rankv (nb);
717
- int * ranks = rankv.data();
727
+ std::vector<int> rankv(nb);
728
+ int* ranks = rankv.data();
718
729
 
719
730
  // elements in each code bin, ordered by rank within each bin
720
- std::vector<std::vector<int> > tab (nc);
731
+ std::vector<std::vector<int>> tab(nc);
721
732
 
722
733
  { // build rank table
723
734
  IndirectSort s = {gtd};
724
- for (int j = 0; j < nb; j++) ranks[j] = j;
725
- std::sort (ranks, ranks + nb, s);
735
+ for (int j = 0; j < nb; j++)
736
+ ranks[j] = j;
737
+ std::sort(ranks, ranks + nb, s);
726
738
  }
727
739
 
728
740
  for (int rank = 0; rank < nb; rank++) {
729
- int i = ranks [rank];
730
- tab [cb[i]].push_back (rank);
741
+ int i = ranks[rank];
742
+ tab[cb[i]].push_back(rank);
731
743
  }
732
744
 
733
-
734
745
  // this is very expensive. Any suggestion for improvement
735
746
  // welcome.
736
747
  for (int i = 0; i < nc; i++) {
737
- std::vector<int> & di = tab[i];
748
+ std::vector<int>& di = tab[i];
738
749
  for (int j = 0; j < nc; j++) {
739
- std::vector<int> & dj = tab[j];
740
- n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);
741
-
750
+ std::vector<int>& dj = tab[j];
751
+ n_gt_q[i * nc + j] += accum_gt_weight_diff(di, dj);
742
752
  }
743
753
  }
744
-
745
754
  }
746
-
747
755
  }
748
-
749
756
  };
750
757
 
751
-
752
758
  /*****************************************
753
759
  * PolysemousTraining
754
760
  ******************************************/
755
761
 
756
-
757
-
758
- PolysemousTraining::PolysemousTraining ()
759
- {
762
+ PolysemousTraining::PolysemousTraining() {
760
763
  optimization_type = OT_ReproduceDistances_affine;
761
764
  ntrain_permutation = 0;
762
765
  dis_weight_factor = log(2);
766
+ // max 20 G RAM
767
+ max_memory = (size_t)(20) * 1024 * 1024 * 1024;
763
768
  }
764
769
 
765
-
766
-
767
- void PolysemousTraining::optimize_reproduce_distances (
768
- ProductQuantizer &pq) const
769
- {
770
-
770
+ void PolysemousTraining::optimize_reproduce_distances(
771
+ ProductQuantizer& pq) const {
771
772
  int dsub = pq.dsub;
772
773
 
773
774
  int n = pq.ksub;
774
775
  int nbits = pq.nbits;
775
776
 
776
- #pragma omp parallel for
777
+ size_t mem1 = memory_usage_per_thread(pq);
778
+ int nt = std::min(omp_get_max_threads(), int(pq.M));
779
+ FAISS_THROW_IF_NOT_FMT(
780
+ mem1 < max_memory,
781
+ "Polysemous training will use %zd bytes per thread, while the max is set to %zd",
782
+ mem1,
783
+ max_memory);
784
+
785
+ if (mem1 * nt > max_memory) {
786
+ nt = max_memory / mem1;
787
+ fprintf(stderr,
788
+ "Polysemous training: WARN, reducing number of threads to %d to save memory",
789
+ nt);
790
+ }
791
+
792
+ #pragma omp parallel for num_threads(nt)
777
793
  for (int m = 0; m < pq.M; m++) {
778
794
  std::vector<double> dis_table;
779
795
 
780
796
  // printf ("Optimizing quantizer %d\n", m);
781
797
 
782
- float * centroids = pq.get_centroids (m, 0);
798
+ float* centroids = pq.get_centroids(m, 0);
783
799
 
784
800
  for (int i = 0; i < n; i++) {
785
801
  for (int j = 0; j < n; j++) {
786
- dis_table.push_back (fvec_L2sqr (centroids + i * dsub,
787
- centroids + j * dsub,
788
- dsub));
802
+ dis_table.push_back(fvec_L2sqr(
803
+ centroids + i * dsub, centroids + j * dsub, dsub));
789
804
  }
790
805
  }
791
806
 
792
- std::vector<int> perm (n);
793
- ReproduceWithHammingObjective obj (
794
- nbits, dis_table,
795
- dis_weight_factor);
796
-
807
+ std::vector<int> perm(n);
808
+ ReproduceWithHammingObjective obj(nbits, dis_table, dis_weight_factor);
797
809
 
798
- SimulatedAnnealingOptimizer optim (&obj, *this);
810
+ SimulatedAnnealingOptimizer optim(&obj, *this);
799
811
 
800
812
  if (log_pattern.size()) {
801
813
  char fname[256];
802
- snprintf (fname, 256, log_pattern.c_str(), m);
803
- printf ("opening log file %s\n", fname);
804
- optim.logfile = fopen (fname, "w");
805
- FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
814
+ snprintf(fname, 256, log_pattern.c_str(), m);
815
+ printf("opening log file %s\n", fname);
816
+ optim.logfile = fopen(fname, "w");
817
+ FAISS_THROW_IF_NOT_MSG(optim.logfile, "could not open logfile");
806
818
  }
807
- double final_cost = optim.run_optimization (perm.data());
819
+ double final_cost = optim.run_optimization(perm.data());
808
820
 
809
821
  if (verbose > 0) {
810
- printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
811
- m, optim.init_cost, final_cost);
822
+ printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
823
+ m,
824
+ optim.init_cost,
825
+ final_cost);
812
826
  }
813
827
 
814
- if (log_pattern.size()) fclose (optim.logfile);
828
+ if (log_pattern.size())
829
+ fclose(optim.logfile);
815
830
 
816
831
  std::vector<float> centroids_copy;
817
832
  for (int i = 0; i < dsub * n; i++)
818
- centroids_copy.push_back (centroids[i]);
833
+ centroids_copy.push_back(centroids[i]);
819
834
 
820
835
  for (int i = 0; i < n; i++)
821
- memcpy (centroids + perm[i] * dsub,
822
- centroids_copy.data() + i * dsub,
823
- dsub * sizeof(centroids[0]));
824
-
836
+ memcpy(centroids + perm[i] * dsub,
837
+ centroids_copy.data() + i * dsub,
838
+ dsub * sizeof(centroids[0]));
825
839
  }
826
-
827
840
  }
828
841
 
829
-
830
- void PolysemousTraining::optimize_ranking (
831
- ProductQuantizer &pq, size_t n, const float *x) const
832
- {
833
-
842
+ void PolysemousTraining::optimize_ranking(
843
+ ProductQuantizer& pq,
844
+ size_t n,
845
+ const float* x) const {
834
846
  int dsub = pq.dsub;
835
-
836
847
  int nbits = pq.nbits;
837
848
 
838
- std::vector<uint8_t> all_codes (pq.code_size * n);
849
+ std::vector<uint8_t> all_codes(pq.code_size * n);
839
850
 
840
- pq.compute_codes (x, all_codes.data(), n);
851
+ pq.compute_codes(x, all_codes.data(), n);
841
852
 
842
- FAISS_THROW_IF_NOT (pq.nbits == 8);
853
+ FAISS_THROW_IF_NOT(pq.nbits == 8);
843
854
 
844
- if (n == 0)
845
- pq.compute_sdc_table ();
855
+ if (n == 0) {
856
+ pq.compute_sdc_table();
857
+ }
846
858
 
847
859
  #pragma omp parallel for
848
860
  for (int m = 0; m < pq.M; m++) {
849
861
  size_t nq, nb;
850
- std::vector <uint32_t> codes; // query codes, then db codes
851
- std::vector <float> gt_distances; // nq * nb matrix of distances
862
+ std::vector<uint32_t> codes; // query codes, then db codes
863
+ std::vector<float> gt_distances; // nq * nb matrix of distances
852
864
 
853
865
  if (n > 0) {
854
- std::vector<float> xtrain (n * dsub);
866
+ std::vector<float> xtrain(n * dsub);
855
867
  for (int i = 0; i < n; i++)
856
- memcpy (xtrain.data() + i * dsub,
857
- x + i * pq.d + m * dsub,
858
- sizeof(float) * dsub);
868
+ memcpy(xtrain.data() + i * dsub,
869
+ x + i * pq.d + m * dsub,
870
+ sizeof(float) * dsub);
859
871
 
860
- codes.resize (n);
872
+ codes.resize(n);
861
873
  for (int i = 0; i < n; i++)
862
- codes [i] = all_codes [i * pq.code_size + m];
874
+ codes[i] = all_codes[i * pq.code_size + m];
863
875
 
864
- nq = n / 4; nb = n - nq;
865
- const float *xq = xtrain.data();
866
- const float *xb = xq + nq * dsub;
876
+ nq = n / 4;
877
+ nb = n - nq;
878
+ const float* xq = xtrain.data();
879
+ const float* xb = xq + nq * dsub;
867
880
 
868
- gt_distances.resize (nq * nb);
881
+ gt_distances.resize(nq * nb);
869
882
 
870
- pairwise_L2sqr (dsub,
871
- nq, xq,
872
- nb, xb,
873
- gt_distances.data());
883
+ pairwise_L2sqr(dsub, nq, xq, nb, xb, gt_distances.data());
874
884
  } else {
875
885
  nq = nb = pq.ksub;
876
- codes.resize (2 * nq);
886
+ codes.resize(2 * nq);
877
887
  for (int i = 0; i < nq; i++)
878
- codes[i] = codes [i + nq] = i;
888
+ codes[i] = codes[i + nq] = i;
879
889
 
880
- gt_distances.resize (nq * nb);
890
+ gt_distances.resize(nq * nb);
881
891
 
882
- memcpy (gt_distances.data (),
883
- pq.sdc_table.data () + m * nq * nb,
884
- sizeof (float) * nq * nb);
892
+ memcpy(gt_distances.data(),
893
+ pq.sdc_table.data() + m * nq * nb,
894
+ sizeof(float) * nq * nb);
885
895
  }
886
896
 
887
- double t0 = getmillisecs ();
897
+ double t0 = getmillisecs();
888
898
 
889
- PermutationObjective *obj = new RankingScore2 (
890
- nbits, nq, nb,
891
- codes.data(), codes.data() + nq,
892
- gt_distances.data ());
893
- ScopeDeleter1<PermutationObjective> del (obj);
899
+ PermutationObjective* obj = new RankingScore2(
900
+ nbits,
901
+ nq,
902
+ nb,
903
+ codes.data(),
904
+ codes.data() + nq,
905
+ gt_distances.data());
906
+ ScopeDeleter1<PermutationObjective> del(obj);
894
907
 
895
908
  if (verbose > 0) {
896
909
  printf(" m=%d, nq=%zd, nb=%zd, intialize RankingScore "
897
910
  "in %.3f ms\n",
898
- m, nq, nb, getmillisecs () - t0);
911
+ m,
912
+ nq,
913
+ nb,
914
+ getmillisecs() - t0);
899
915
  }
900
916
 
901
- SimulatedAnnealingOptimizer optim (obj, *this);
917
+ SimulatedAnnealingOptimizer optim(obj, *this);
902
918
 
903
919
  if (log_pattern.size()) {
904
920
  char fname[256];
905
- snprintf (fname, 256, log_pattern.c_str(), m);
906
- printf ("opening log file %s\n", fname);
907
- optim.logfile = fopen (fname, "w");
908
- FAISS_THROW_IF_NOT_FMT (optim.logfile,
909
- "could not open logfile %s", fname);
921
+ snprintf(fname, 256, log_pattern.c_str(), m);
922
+ printf("opening log file %s\n", fname);
923
+ optim.logfile = fopen(fname, "w");
924
+ FAISS_THROW_IF_NOT_FMT(
925
+ optim.logfile, "could not open logfile %s", fname);
910
926
  }
911
927
 
912
- std::vector<int> perm (pq.ksub);
928
+ std::vector<int> perm(pq.ksub);
913
929
 
914
- double final_cost = optim.run_optimization (perm.data());
915
- printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
916
- m, optim.init_cost, final_cost);
930
+ double final_cost = optim.run_optimization(perm.data());
931
+ printf("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
932
+ m,
933
+ optim.init_cost,
934
+ final_cost);
917
935
 
918
- if (log_pattern.size()) fclose (optim.logfile);
936
+ if (log_pattern.size())
937
+ fclose(optim.logfile);
919
938
 
920
- float * centroids = pq.get_centroids (m, 0);
939
+ float* centroids = pq.get_centroids(m, 0);
921
940
 
922
941
  std::vector<float> centroids_copy;
923
942
  for (int i = 0; i < dsub * pq.ksub; i++)
924
- centroids_copy.push_back (centroids[i]);
943
+ centroids_copy.push_back(centroids[i]);
925
944
 
926
945
  for (int i = 0; i < pq.ksub; i++)
927
- memcpy (centroids + perm[i] * dsub,
928
- centroids_copy.data() + i * dsub,
929
- dsub * sizeof(centroids[0]));
930
-
946
+ memcpy(centroids + perm[i] * dsub,
947
+ centroids_copy.data() + i * dsub,
948
+ dsub * sizeof(centroids[0]));
931
949
  }
932
-
933
950
  }
934
951
 
935
-
936
-
937
- void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq,
938
- size_t n, const float *x) const
939
- {
952
+ void PolysemousTraining::optimize_pq_for_hamming(
953
+ ProductQuantizer& pq,
954
+ size_t n,
955
+ const float* x) const {
940
956
  if (optimization_type == OT_None) {
941
-
942
957
  } else if (optimization_type == OT_ReproduceDistances_affine) {
943
- optimize_reproduce_distances (pq);
958
+ optimize_reproduce_distances(pq);
944
959
  } else {
945
- optimize_ranking (pq, n, x);
960
+ optimize_ranking(pq, n, x);
946
961
  }
947
962
 
948
- pq.compute_sdc_table ();
949
-
963
+ pq.compute_sdc_table();
950
964
  }
951
965
 
966
+ size_t PolysemousTraining::memory_usage_per_thread(
967
+ const ProductQuantizer& pq) const {
968
+ size_t n = pq.ksub;
969
+
970
+ switch (optimization_type) {
971
+ case OT_None:
972
+ return 0;
973
+ case OT_ReproduceDistances_affine:
974
+ return n * n * sizeof(double) * 3;
975
+ case OT_Ranking_weighted_diff:
976
+ return n * n * n * sizeof(float);
977
+ }
978
+
979
+ FAISS_THROW_MSG("Invalid optmization type");
980
+ return 0;
981
+ }
952
982
 
953
983
  } // namespace faiss