faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,275 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <vector>
13
+ #include <unordered_set>
14
+ #include <queue>
15
+
16
+ #include <omp.h>
17
+
18
+ #include <faiss/Index.h>
19
+ #include <faiss/impl/FaissAssert.h>
20
+ #include <faiss/utils/random.h>
21
+ #include <faiss/utils/Heap.h>
22
+
23
+
24
+ namespace faiss {
25
+
26
+
27
+ /** Implementation of the Hierarchical Navigable Small World
28
+ * datastructure.
29
+ *
30
+ * Efficient and robust approximate nearest neighbor search using
31
+ * Hierarchical Navigable Small World graphs
32
+ *
33
+ * Yu. A. Malkov, D. A. Yashunin, arXiv 2017
34
+ *
35
 + * This implementation is heavily influenced by the NMSlib
36
 + * implementation by Yury Malkov and Leonid Boytsov
37
+ * (https://github.com/searchivarius/nmslib)
38
+ *
39
+ * The HNSW object stores only the neighbor link structure, see
40
+ * IndexHNSW.h for the full index object.
41
+ */
42
+
43
+
44
+ struct VisitedTable;
45
+ struct DistanceComputer; // from AuxIndexStructures
46
+
47
+ struct HNSW {
48
+ /// internal storage of vectors (32 bits: this is expensive)
49
+ typedef int storage_idx_t;
50
+
51
+ /// Faiss results are 64-bit
52
+ typedef Index::idx_t idx_t;
53
+
54
+ typedef std::pair<float, storage_idx_t> Node;
55
+
56
+ /** Heap structure that allows fast
57
+ */
58
+ struct MinimaxHeap {
59
+ int n;
60
+ int k;
61
+ int nvalid;
62
+
63
+ std::vector<storage_idx_t> ids;
64
+ std::vector<float> dis;
65
+ typedef faiss::CMax<float, storage_idx_t> HC;
66
+
67
+ explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {}
68
+
69
+ void push(storage_idx_t i, float v);
70
+
71
+ float max() const;
72
+
73
+ int size() const;
74
+
75
+ void clear();
76
+
77
+ int pop_min(float *vmin_out = nullptr);
78
+
79
+ int count_below(float thresh);
80
+ };
81
+
82
+
83
 + /// to sort pairs of (id, distance) from nearest to farthest or the reverse
84
+ struct NodeDistCloser {
85
+ float d;
86
+ int id;
87
+ NodeDistCloser(float d, int id): d(d), id(id) {}
88
+ bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; }
89
+ };
90
+
91
+ struct NodeDistFarther {
92
+ float d;
93
+ int id;
94
+ NodeDistFarther(float d, int id): d(d), id(id) {}
95
+ bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; }
96
+ };
97
+
98
+
99
+ /// assignment probability to each layer (sum=1)
100
+ std::vector<double> assign_probas;
101
+
102
+ /// number of neighbors stored per layer (cumulative), should not
103
+ /// be changed after first add
104
+ std::vector<int> cum_nneighbor_per_level;
105
+
106
+ /// level of each vector (base level = 1), size = ntotal
107
+ std::vector<int> levels;
108
+
109
+ /// offsets[i] is the offset in the neighbors array where vector i is stored
110
+ /// size ntotal + 1
111
+ std::vector<size_t> offsets;
112
+
113
+ /// neighbors[offsets[i]:offsets[i+1]] is the list of neighbors of vector i
114
+ /// for all levels. this is where all storage goes.
115
+ std::vector<storage_idx_t> neighbors;
116
+
117
 + /// entry point in the search structure (one of the points with maximum level)
118
+ storage_idx_t entry_point;
119
+
120
+ faiss::RandomGenerator rng;
121
+
122
+ /// maximum level
123
+ int max_level;
124
+
125
+ /// expansion factor at construction time
126
+ int efConstruction;
127
+
128
+ /// expansion factor at search time
129
+ int efSearch;
130
+
131
+ /// during search: do we check whether the next best distance is good enough?
132
+ bool check_relative_distance = true;
133
+
134
+ /// number of entry points in levels > 0.
135
+ int upper_beam;
136
+
137
+ /// use bounded queue during exploration
138
+ bool search_bounded_queue = true;
139
+
140
+ // methods that initialize the tree sizes
141
+
142
+ /// initialize the assign_probas and cum_nneighbor_per_level to
143
+ /// have 2*M links on level 0 and M links on levels > 0
144
+ void set_default_probas(int M, float levelMult);
145
+
146
+ /// set nb of neighbors for this level (before adding anything)
147
+ void set_nb_neighbors(int level_no, int n);
148
+
149
+ // methods that access the tree sizes
150
+
151
+ /// nb of neighbors for this level
152
+ int nb_neighbors(int layer_no) const;
153
+
154
 + /// cumulative nb up to (and excluding) this level
155
+ int cum_nb_neighbors(int layer_no) const;
156
+
157
+ /// range of entries in the neighbors table of vertex no at layer_no
158
+ void neighbor_range(idx_t no, int layer_no,
159
+ size_t * begin, size_t * end) const;
160
+
161
+ /// only mandatory parameter: nb of neighbors
162
+ explicit HNSW(int M = 32);
163
+
164
+ /// pick a random level for a new point
165
+ int random_level();
166
+
167
+ /// add n random levels to table (for debugging...)
168
+ void fill_with_random_links(size_t n);
169
+
170
+ void add_links_starting_from(DistanceComputer& ptdis,
171
+ storage_idx_t pt_id,
172
+ storage_idx_t nearest,
173
+ float d_nearest,
174
+ int level,
175
+ omp_lock_t *locks,
176
+ VisitedTable &vt);
177
+
178
+
179
+ /** add point pt_id on all levels <= pt_level and build the link
180
+ * structure for them. */
181
+ void add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
182
+ std::vector<omp_lock_t>& locks,
183
+ VisitedTable& vt);
184
+
185
+ int search_from_candidates(DistanceComputer& qdis, int k,
186
+ idx_t *I, float *D,
187
+ MinimaxHeap& candidates,
188
+ VisitedTable &vt,
189
+ int level, int nres_in = 0) const;
190
+
191
+ std::priority_queue<Node> search_from_candidate_unbounded(
192
+ const Node& node,
193
+ DistanceComputer& qdis,
194
+ int ef,
195
+ VisitedTable *vt
196
+ ) const;
197
+
198
+ /// search interface
199
+ void search(DistanceComputer& qdis, int k,
200
+ idx_t *I, float *D,
201
+ VisitedTable& vt) const;
202
+
203
+ void reset();
204
+
205
+ void clear_neighbor_tables(int level);
206
+ void print_neighbor_stats(int level) const;
207
+
208
+ int prepare_level_tab(size_t n, bool preset_levels = false);
209
+
210
+ static void shrink_neighbor_list(
211
+ DistanceComputer& qdis,
212
+ std::priority_queue<NodeDistFarther>& input,
213
+ std::vector<NodeDistFarther>& output,
214
+ int max_size);
215
+
216
+ };
217
+
218
+
219
+ /**************************************************************
220
+ * Auxiliary structures
221
+ **************************************************************/
222
+
223
+ /// set implementation optimized for fast access.
224
+ struct VisitedTable {
225
+ std::vector<uint8_t> visited;
226
+ int visno;
227
+
228
+ explicit VisitedTable(int size)
229
+ : visited(size), visno(1) {}
230
+
231
 + /// set flag #no to true
232
+ void set(int no) {
233
+ visited[no] = visno;
234
+ }
235
+
236
+ /// get flag #no
237
+ bool get(int no) const {
238
+ return visited[no] == visno;
239
+ }
240
+
241
+ /// reset all flags to false
242
+ void advance() {
243
+ visno++;
244
+ if (visno == 250) {
245
+ // 250 rather than 255 because sometimes we use visno and visno+1
246
+ memset(visited.data(), 0, sizeof(visited[0]) * visited.size());
247
+ visno = 1;
248
+ }
249
+ }
250
+ };
251
+
252
+
253
+ struct HNSWStats {
254
+ size_t n1, n2, n3;
255
+ size_t ndis;
256
+ size_t nreorder;
257
+ bool view;
258
+
259
+ HNSWStats() {
260
+ reset();
261
+ }
262
+
263
+ void reset() {
264
+ n1 = n2 = n3 = 0;
265
+ ndis = 0;
266
+ nreorder = 0;
267
+ view = false;
268
+ }
269
+ };
270
+
271
+ // global var that collects them all
272
+ extern HNSWStats hnsw_stats;
273
+
274
+
275
+ } // namespace faiss
@@ -0,0 +1,953 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/PolysemousTraining.h>
11
+
12
+ #include <cstdlib>
13
+ #include <cmath>
14
+ #include <cstring>
15
+ #include <stdint.h>
16
+
17
+ #include <algorithm>
18
+
19
+ #include <faiss/utils/random.h>
20
+ #include <faiss/utils/utils.h>
21
+ #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/hamming.h>
23
+
24
+ #include <faiss/impl/FaissAssert.h>
25
+
26
+ /*****************************************
27
+ * Mixed PQ / Hamming
28
+ ******************************************/
29
+
30
+ namespace faiss {
31
+
32
+
33
+ /****************************************************
34
+ * Optimization code
35
+ ****************************************************/
36
+
37
+ SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
38
+ {
39
+ // set some reasonable defaults for the optimization
40
+ init_temperature = 0.7;
41
+ temperature_decay = pow (0.9, 1/500.);
42
+ // reduce by a factor 0.9 every 500 it
43
+ n_iter = 500000;
44
+ n_redo = 2;
45
+ seed = 123;
46
+ verbose = 0;
47
+ only_bit_flips = false;
48
+ init_random = false;
49
+ }
50
+
51
+ // what would the cost update be if iw and jw were swapped?
52
+ // default implementation just computes both and computes the difference
53
+ double PermutationObjective::cost_update (
54
+ const int *perm, int iw, int jw) const
55
+ {
56
+ double orig_cost = compute_cost (perm);
57
+
58
+ std::vector<int> perm2 (n);
59
+ for (int i = 0; i < n; i++)
60
+ perm2[i] = perm[i];
61
+ perm2[iw] = perm[jw];
62
+ perm2[jw] = perm[iw];
63
+
64
+ double new_cost = compute_cost (perm2.data());
65
+ return new_cost - orig_cost;
66
+ }
67
+
68
+
69
+
70
+
71
+ SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
72
+ PermutationObjective *obj,
73
+ const SimulatedAnnealingParameters &p):
74
+ SimulatedAnnealingParameters (p),
75
+ obj (obj),
76
+ n(obj->n),
77
+ logfile (nullptr)
78
+ {
79
+ rnd = new RandomGenerator (p.seed);
80
+ FAISS_THROW_IF_NOT (n < 100000 && n >=0 );
81
+ }
82
+
83
+ SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
84
+ {
85
+ delete rnd;
86
+ }
87
+
88
+ // run the optimization and return the best result in best_perm
89
+ double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
90
+ {
91
+ double min_cost = 1e30;
92
+
93
+ // just do a few runs of the annealing and keep the lowest output cost
94
+ for (int it = 0; it < n_redo; it++) {
95
+ std::vector<int> perm(n);
96
+ for (int i = 0; i < n; i++)
97
+ perm[i] = i;
98
+ if (init_random) {
99
+ for (int i = 0; i < n; i++) {
100
+ int j = i + rnd->rand_int (n - i);
101
+ std::swap (perm[i], perm[j]);
102
+ }
103
+ }
104
+ float cost = optimize (perm.data());
105
+ if (logfile) fprintf (logfile, "\n");
106
+ if(verbose > 1) {
107
+ printf (" optimization run %d: cost=%g %s\n",
108
+ it, cost, cost < min_cost ? "keep" : "");
109
+ }
110
+ if (cost < min_cost) {
111
+ memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
112
+ min_cost = cost;
113
+ }
114
+ }
115
+ return min_cost;
116
+ }
117
+
118
+ // perform the optimization loop, starting from and modifying
119
+ // permutation in-place
120
+ double SimulatedAnnealingOptimizer::optimize (int *perm)
121
+ {
122
+ double cost = init_cost = obj->compute_cost (perm);
123
+ int log2n = 0;
124
+ while (!(n <= (1 << log2n))) log2n++;
125
+ double temperature = init_temperature;
126
+ int n_swap = 0, n_hot = 0;
127
+ for (int it = 0; it < n_iter; it++) {
128
+ temperature = temperature * temperature_decay;
129
+ int iw, jw;
130
+ if (only_bit_flips) {
131
+ iw = rnd->rand_int (n);
132
+ jw = iw ^ (1 << rnd->rand_int (log2n));
133
+ } else {
134
+ iw = rnd->rand_int (n);
135
+ jw = rnd->rand_int (n - 1);
136
+ if (jw == iw) jw++;
137
+ }
138
+ double delta_cost = obj->cost_update (perm, iw, jw);
139
+ if (delta_cost < 0 || rnd->rand_float () < temperature) {
140
+ std::swap (perm[iw], perm[jw]);
141
+ cost += delta_cost;
142
+ n_swap++;
143
+ if (delta_cost >= 0) n_hot++;
144
+ }
145
+ if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
146
+ printf (" iteration %d cost %g temp %g n_swap %d "
147
+ "(%d hot) \r",
148
+ it, cost, temperature, n_swap, n_hot);
149
+ fflush(stdout);
150
+ }
151
+ if (logfile) {
152
+ fprintf (logfile, "%d %g %g %d %d\n",
153
+ it, cost, temperature, n_swap, n_hot);
154
+ }
155
+ }
156
+ if (verbose > 1) printf("\n");
157
+ return cost;
158
+ }
159
+
160
+
161
+
162
+
163
+
164
+ /****************************************************
165
+ * Cost functions: ReproduceDistanceTable
166
+ ****************************************************/
167
+
168
+
169
+
170
+
171
+
172
+
173
+ static inline int hamming_dis (uint64_t a, uint64_t b)
174
+ {
175
+ return __builtin_popcountl (a ^ b);
176
+ }
177
+
178
+ namespace {
179
+
180
+ /// optimize permutation to reproduce a distance table with Hamming distances
181
+ struct ReproduceWithHammingObjective : PermutationObjective {
182
+ int nbits;
183
+ double dis_weight_factor;
184
+
185
+ static double sqr (double x) { return x * x; }
186
+
187
+
188
+ // weihgting of distances: it is more important to reproduce small
189
+ // distances well
190
+ double dis_weight (double x) const
191
+ {
192
+ return exp (-dis_weight_factor * x);
193
+ }
194
+
195
+ std::vector<double> target_dis; // wanted distances (size n^2)
196
+ std::vector<double> weights; // weights for each distance (size n^2)
197
+
198
+ // cost = quadratic difference between actual distance and Hamming distance
199
+ double compute_cost(const int* perm) const override {
200
+ double cost = 0;
201
+ for (int i = 0; i < n; i++) {
202
+ for (int j = 0; j < n; j++) {
203
+ double wanted = target_dis[i * n + j];
204
+ double w = weights[i * n + j];
205
+ double actual = hamming_dis(perm[i], perm[j]);
206
+ cost += w * sqr(wanted - actual);
207
+ }
208
+ }
209
+ return cost;
210
+ }
211
+
212
+
213
+ // what would the cost update be if iw and jw were swapped?
214
+ // computed in O(n) instead of O(n^2) for the full re-computation
215
+ double cost_update(const int* perm, int iw, int jw) const override {
216
+ double delta_cost = 0;
217
+
218
+ for (int i = 0; i < n; i++) {
219
+ if (i == iw) {
220
+ for (int j = 0; j < n; j++) {
221
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
222
+ double actual = hamming_dis(perm[i], perm[j]);
223
+ delta_cost -= w * sqr(wanted - actual);
224
+ double new_actual =
225
+ hamming_dis(perm[jw], perm[j == iw ? jw : j == jw ? iw : j]);
226
+ delta_cost += w * sqr(wanted - new_actual);
227
+ }
228
+ } else if (i == jw) {
229
+ for (int j = 0; j < n; j++) {
230
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
231
+ double actual = hamming_dis(perm[i], perm[j]);
232
+ delta_cost -= w * sqr(wanted - actual);
233
+ double new_actual =
234
+ hamming_dis(perm[iw], perm[j == iw ? jw : j == jw ? iw : j]);
235
+ delta_cost += w * sqr(wanted - new_actual);
236
+ }
237
+ } else {
238
+ int j = iw;
239
+ {
240
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
241
+ double actual = hamming_dis(perm[i], perm[j]);
242
+ delta_cost -= w * sqr(wanted - actual);
243
+ double new_actual = hamming_dis(perm[i], perm[jw]);
244
+ delta_cost += w * sqr(wanted - new_actual);
245
+ }
246
+ j = jw;
247
+ {
248
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
249
+ double actual = hamming_dis(perm[i], perm[j]);
250
+ delta_cost -= w * sqr(wanted - actual);
251
+ double new_actual = hamming_dis(perm[i], perm[iw]);
252
+ delta_cost += w * sqr(wanted - new_actual);
253
+ }
254
+ }
255
+ }
256
+
257
+ return delta_cost;
258
+ }
259
+
260
+
261
+
262
+ ReproduceWithHammingObjective (
263
+ int nbits,
264
+ const std::vector<double> & dis_table,
265
+ double dis_weight_factor):
266
+ nbits (nbits), dis_weight_factor (dis_weight_factor)
267
+ {
268
+ n = 1 << nbits;
269
+ FAISS_THROW_IF_NOT (dis_table.size() == n * n);
270
+ set_affine_target_dis (dis_table);
271
+ }
272
+
273
+ void set_affine_target_dis (const std::vector<double> & dis_table)
274
+ {
275
+ double sum = 0, sum2 = 0;
276
+ int n2 = n * n;
277
+ for (int i = 0; i < n2; i++) {
278
+ sum += dis_table [i];
279
+ sum2 += dis_table [i] * dis_table [i];
280
+ }
281
+ double mean = sum / n2;
282
+ double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
283
+
284
+ target_dis.resize (n2);
285
+
286
+ for (int i = 0; i < n2; i++) {
287
+ // the mapping function
288
+ double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
289
+ nbits / 2;
290
+ target_dis[i] = td;
291
+ // compute a weight
292
+ weights.push_back (dis_weight (td));
293
+ }
294
+
295
+ }
296
+
297
+ ~ReproduceWithHammingObjective() override {}
298
+ };
299
+
300
+ } // anonymous namespace
301
+
302
+ // weihgting of distances: it is more important to reproduce small
303
+ // distances well
304
+ double ReproduceDistancesObjective::dis_weight (double x) const
305
+ {
306
+ return exp (-dis_weight_factor * x);
307
+ }
308
+
309
+
310
+ double ReproduceDistancesObjective::get_source_dis (int i, int j) const
311
+ {
312
+ return source_dis [i * n + j];
313
+ }
314
+
315
+ // cost = quadratic difference between actual distance and Hamming distance
316
+ double ReproduceDistancesObjective::compute_cost (const int *perm) const
317
+ {
318
+ double cost = 0;
319
+ for (int i = 0; i < n; i++) {
320
+ for (int j = 0; j < n; j++) {
321
+ double wanted = target_dis [i * n + j];
322
+ double w = weights [i * n + j];
323
+ double actual = get_source_dis (perm[i], perm[j]);
324
+ cost += w * sqr (wanted - actual);
325
+ }
326
+ }
327
+ return cost;
328
+ }
329
+
330
+ // what would the cost update be if iw and jw were swapped?
331
+ // computed in O(n) instead of O(n^2) for the full re-computation
332
+ double ReproduceDistancesObjective::cost_update(
333
+ const int *perm, int iw, int jw) const
334
+ {
335
+ double delta_cost = 0;
336
+ for (int i = 0; i < n; i++) {
337
+ if (i == iw) {
338
+ for (int j = 0; j < n; j++) {
339
+ double wanted = target_dis [i * n + j],
340
+ w = weights [i * n + j];
341
+ double actual = get_source_dis (perm[i], perm[j]);
342
+ delta_cost -= w * sqr (wanted - actual);
343
+ double new_actual = get_source_dis (
344
+ perm[jw],
345
+ perm[j == iw ? jw : j == jw ? iw : j]);
346
+ delta_cost += w * sqr (wanted - new_actual);
347
+ }
348
+ } else if (i == jw) {
349
+ for (int j = 0; j < n; j++) {
350
+ double wanted = target_dis [i * n + j],
351
+ w = weights [i * n + j];
352
+ double actual = get_source_dis (perm[i], perm[j]);
353
+ delta_cost -= w * sqr (wanted - actual);
354
+ double new_actual = get_source_dis (
355
+ perm[iw],
356
+ perm[j == iw ? jw : j == jw ? iw : j]);
357
+ delta_cost += w * sqr (wanted - new_actual);
358
+ }
359
+ } else {
360
+ int j = iw;
361
+ {
362
+ double wanted = target_dis [i * n + j],
363
+ w = weights [i * n + j];
364
+ double actual = get_source_dis (perm[i], perm[j]);
365
+ delta_cost -= w * sqr (wanted - actual);
366
+ double new_actual = get_source_dis (perm[i], perm[jw]);
367
+ delta_cost += w * sqr (wanted - new_actual);
368
+ }
369
+ j = jw;
370
+ {
371
+ double wanted = target_dis [i * n + j],
372
+ w = weights [i * n + j];
373
+ double actual = get_source_dis (perm[i], perm[j]);
374
+ delta_cost -= w * sqr (wanted - actual);
375
+ double new_actual = get_source_dis (perm[i], perm[iw]);
376
+ delta_cost += w * sqr (wanted - new_actual);
377
+ }
378
+ }
379
+ }
380
+ return delta_cost;
381
+ }
382
+
383
+
384
+
385
+ ReproduceDistancesObjective::ReproduceDistancesObjective (
386
+ int n,
387
+ const double *source_dis_in,
388
+ const double *target_dis_in,
389
+ double dis_weight_factor):
390
+ dis_weight_factor (dis_weight_factor),
391
+ target_dis (target_dis_in)
392
+ {
393
+ this->n = n;
394
+ set_affine_target_dis (source_dis_in);
395
+ }
396
+
397
+ void ReproduceDistancesObjective::compute_mean_stdev (
398
+ const double *tab, size_t n2,
399
+ double *mean_out, double *stddev_out)
400
+ {
401
+ double sum = 0, sum2 = 0;
402
+ for (int i = 0; i < n2; i++) {
403
+ sum += tab [i];
404
+ sum2 += tab [i] * tab [i];
405
+ }
406
+ double mean = sum / n2;
407
+ double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
408
+ *mean_out = mean;
409
+ *stddev_out = stddev;
410
+ }
411
+
412
+ void ReproduceDistancesObjective::set_affine_target_dis (
413
+ const double *source_dis_in)
414
+ {
415
+ int n2 = n * n;
416
+
417
+ double mean_src, stddev_src;
418
+ compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);
419
+
420
+ double mean_target, stddev_target;
421
+ compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);
422
+
423
+ printf ("map mean %g std %g -> mean %g std %g\n",
424
+ mean_src, stddev_src, mean_target, stddev_target);
425
+
426
+ source_dis.resize (n2);
427
+ weights.resize (n2);
428
+
429
+ for (int i = 0; i < n2; i++) {
430
+ // the mapping function
431
+ source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
432
+ * stddev_target + mean_target;
433
+
434
+ // compute a weight
435
+ weights [i] = dis_weight (target_dis[i]);
436
+ }
437
+
438
+ }
439
+
440
+ /****************************************************
441
+ * Cost functions: RankingScore
442
+ ****************************************************/
443
+
444
+ /// Maintains a 3D table of elementary costs.
445
+ /// Accumulates elements based on Hamming distance comparisons
446
+ template <typename Ttab, typename Taccu>
447
+ struct Score3Computer: PermutationObjective {
448
+
449
+ int nc;
450
+
451
+ // cost matrix of size nc * nc *nc
452
+ // n_gt (i,j,k) = count of d_gt(x, y-) < d_gt(x, y+)
453
+ // where x has PQ code i, y- PQ code j and y+ PQ code k
454
+ std::vector<Ttab> n_gt;
455
+
456
+
457
+ /// the cost is a triple loop on the nc * nc * nc matrix of entries.
458
+ ///
459
+ Taccu compute (const int * perm) const
460
+ {
461
+ Taccu accu = 0;
462
+ const Ttab *p = n_gt.data();
463
+ for (int i = 0; i < nc; i++) {
464
+ int ip = perm [i];
465
+ for (int j = 0; j < nc; j++) {
466
+ int jp = perm [j];
467
+ for (int k = 0; k < nc; k++) {
468
+ int kp = perm [k];
469
+ if (hamming_dis (ip, jp) <
470
+ hamming_dis (ip, kp)) {
471
+ accu += *p; // n_gt [ ( i * nc + j) * nc + k];
472
+ }
473
+ p++;
474
+ }
475
+ }
476
+ }
477
+ return accu;
478
+ }
479
+
480
+
481
+ /** cost update if entries iw and jw of the permutation would be
482
+ * swapped.
483
+ *
484
+ * The computation is optimized by avoiding elements in the
485
+ * nc*nc*nc cube that are known not to change. For nc=256, this
486
+ * reduces the nb of cells to visit to about 6/256 th of the
487
+ * cells. Practical speedup is about 8x, and the code is quite
488
+ * complex :-/
489
+ */
490
+ Taccu compute_update (const int *perm, int iw, int jw) const
491
+ {
492
+ assert (iw != jw);
493
+ if (iw > jw) std::swap (iw, jw);
494
+
495
+ Taccu accu = 0;
496
+ const Ttab * n_gt_i = n_gt.data();
497
+ for (int i = 0; i < nc; i++) {
498
+ int ip0 = perm [i];
499
+ int ip = perm [i == iw ? jw : i == jw ? iw : i];
500
+
501
+ //accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
502
+
503
+ accu += update_i_cross (perm, iw, jw,
504
+ ip0, ip, n_gt_i);
505
+
506
+ if (ip != ip0)
507
+ accu += update_i_plane (perm, iw, jw,
508
+ ip0, ip, n_gt_i);
509
+
510
+ n_gt_i += nc * nc;
511
+ }
512
+
513
+ return accu;
514
+ }
515
+
516
+
517
+ Taccu update_i (const int *perm, int iw, int jw,
518
+ int ip0, int ip, const Ttab * n_gt_i) const
519
+ {
520
+ Taccu accu = 0;
521
+ const Ttab *n_gt_ij = n_gt_i;
522
+ for (int j = 0; j < nc; j++) {
523
+ int jp0 = perm[j];
524
+ int jp = perm [j == iw ? jw : j == jw ? iw : j];
525
+ for (int k = 0; k < nc; k++) {
526
+ int kp0 = perm [k];
527
+ int kp = perm [k == iw ? jw : k == jw ? iw : k];
528
+ int ng = n_gt_ij [k];
529
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
530
+ accu += ng;
531
+ }
532
+ if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
533
+ accu -= ng;
534
+ }
535
+ }
536
+ n_gt_ij += nc;
537
+ }
538
+ return accu;
539
+ }
540
+
541
+ // 2 inner loops for the case ip0 != ip
542
+ Taccu update_i_plane (const int *perm, int iw, int jw,
543
+ int ip0, int ip, const Ttab * n_gt_i) const
544
+ {
545
+ Taccu accu = 0;
546
+ const Ttab *n_gt_ij = n_gt_i;
547
+
548
+ for (int j = 0; j < nc; j++) {
549
+ if (j != iw && j != jw) {
550
+ int jp = perm[j];
551
+ for (int k = 0; k < nc; k++) {
552
+ if (k != iw && k != jw) {
553
+ int kp = perm [k];
554
+ Ttab ng = n_gt_ij [k];
555
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
556
+ accu += ng;
557
+ }
558
+ if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
559
+ accu -= ng;
560
+ }
561
+ }
562
+ }
563
+ }
564
+ n_gt_ij += nc;
565
+ }
566
+ return accu;
567
+ }
568
+
569
+ /// used for the 8 cells were the 3 indices are swapped
570
+ inline Taccu update_k (const int *perm, int iw, int jw,
571
+ int ip0, int ip, int jp0, int jp,
572
+ int k,
573
+ const Ttab * n_gt_ij) const
574
+ {
575
+ Taccu accu = 0;
576
+ int kp0 = perm [k];
577
+ int kp = perm [k == iw ? jw : k == jw ? iw : k];
578
+ Ttab ng = n_gt_ij [k];
579
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
580
+ accu += ng;
581
+ }
582
+ if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
583
+ accu -= ng;
584
+ }
585
+ return accu;
586
+ }
587
+
588
+ /// compute update on a line of k's, where i and j are swapped
589
+ Taccu update_j_line (const int *perm, int iw, int jw,
590
+ int ip0, int ip, int jp0, int jp,
591
+ const Ttab * n_gt_ij) const
592
+ {
593
+ Taccu accu = 0;
594
+ for (int k = 0; k < nc; k++) {
595
+ if (k == iw || k == jw) continue;
596
+ int kp = perm [k];
597
+ Ttab ng = n_gt_ij [k];
598
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
599
+ accu += ng;
600
+ }
601
+ if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
602
+ accu -= ng;
603
+ }
604
+ }
605
+ return accu;
606
+ }
607
+
608
+
609
+ /// considers the 2 pairs of crossing lines j=iw or jw and k = iw or kw
610
+ Taccu update_i_cross (const int *perm, int iw, int jw,
611
+ int ip0, int ip, const Ttab * n_gt_i) const
612
+ {
613
+ Taccu accu = 0;
614
+ const Ttab *n_gt_ij = n_gt_i;
615
+
616
+ for (int j = 0; j < nc; j++) {
617
+ int jp0 = perm[j];
618
+ int jp = perm [j == iw ? jw : j == jw ? iw : j];
619
+
620
+ accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
621
+ accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
622
+
623
+ if (jp != jp0)
624
+ accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
625
+
626
+ n_gt_ij += nc;
627
+ }
628
+ return accu;
629
+ }
630
+
631
+
632
+ /// PermutationObjective implementeation (just negates the scores
633
+ /// for minimization)
634
+
635
+ double compute_cost(const int* perm) const override {
636
+ return -compute(perm);
637
+ }
638
+
639
+ double cost_update(const int* perm, int iw, int jw) const override {
640
+ double ret = -compute_update(perm, iw, jw);
641
+ return ret;
642
+ }
643
+
644
+ ~Score3Computer() override {}
645
+ };
646
+
647
+
648
+
649
+
650
+
651
/// Comparator that orders indices by the values they point to in tab
/// (ascending); used with std::sort to build a rank table.
/// tab is not owned and must outlive the sort.
struct IndirectSort {
    const float *tab;
    // const-qualified: a Compare functor should be callable on a
    // const object (the original was not const)
    bool operator () (int a, int b) const {return tab[a] < tab[b]; }
};
655
+
656
+
657
+
658
/// Ranking-based objective: fills the Score3Computer cost cube from
/// ground-truth distances between queries and database points, with
/// rank-dependent weights.
struct RankingScore2: Score3Computer<float, double> {
    int nbits;                        // bits per code; nc = 2^nbits
    int nq, nb;                       // number of query / database points
    const uint32_t *qcodes, *bcodes;  // per-point codes (not owned)
    const float *gt_distances;        // nq * nb ground-truth distances (not owned)

    RankingScore2 (int nbits, int nq, int nb,
                   const uint32_t *qcodes, const uint32_t *bcodes,
                   const float *gt_distances):
        nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
        bcodes(bcodes), gt_distances(gt_distances)
    {
        n = nc = 1 << nbits;
        // cost cube is filled once at construction time
        n_gt.resize (nc * nc * nc);
        init_n_gt ();
    }


    /// weight of rank r in the score (decreasing with rank)
    double rank_weight (int r)
    {
        return 1.0 / (r + 1);
    }

    /// count nb of i, j in a x b st. i < j
    /// a and b should be sorted on input
    /// they are the ranks of j and k respectively.
    /// specific version for diff-of-rank weighting, cannot be optimized
    /// with a cumulative table
    double accum_gt_weight_diff (const std::vector<int> & a,
                                 const std::vector<int> & b)
    {
        // NOTE: these locals intentionally shadow the nq/nb members
        int nb = b.size(), na = a.size();

        double accu = 0;
        int j = 0;
        for (int i = 0; i < na; i++) {
            int ai = a[i];
            // advance j past all elements of b that are <= ai
            // (valid because a and b are sorted)
            while (j < nb && ai >= b[j]) j++;

            // sum the diff-of-rank weights for all b[k] > ai
            double accu_i = 0;
            for (int k = j; k < b.size(); k++)
                accu_i += rank_weight (b[k] - ai);

            accu += rank_weight (ai) * accu_i;

        }
        return accu;
    }

    /// Fill the nc*nc*nc cost cube from the ground-truth distances:
    /// for each query, rank the database points, bin the ranks by
    /// code, and accumulate pairwise rank weights per code pair.
    void init_n_gt ()
    {
        for (int q = 0; q < nq; q++) {
            const float *gtd = gt_distances + q * nb;
            const uint32_t *cb = bcodes;// all same codes
            // slice of the cube for this query's code
            float * n_gt_q = & n_gt [qcodes[q] * nc * nc];

            // progress indicator (overwrites the same terminal line)
            printf("init gt for q=%d/%d \r", q, nq); fflush(stdout);

            std::vector<int> rankv (nb);
            int * ranks = rankv.data();

            // elements in each code bin, ordered by rank within each bin
            std::vector<std::vector<int> > tab (nc);

            { // build rank table
                IndirectSort s = {gtd};
                for (int j = 0; j < nb; j++) ranks[j] = j;
                std::sort (ranks, ranks + nb, s);
            }

            // bin ranks by code; bins stay sorted because ranks are
            // visited in increasing order
            for (int rank = 0; rank < nb; rank++) {
                int i = ranks [rank];
                tab [cb[i]].push_back (rank);
            }


            // this is very expensive. Any suggestion for improvement
            // welcome.
            for (int i = 0; i < nc; i++) {
                std::vector<int> & di = tab[i];
                for (int j = 0; j < nc; j++) {
                    std::vector<int> & dj = tab[j];
                    n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);

                }
            }

        }

    }

};
750
+
751
+
752
+ /*****************************************
753
+ * PolysemousTraining
754
+ ******************************************/
755
+
756
+
757
+
758
/// Set the default training parameters.
PolysemousTraining::PolysemousTraining ()
{
    // default criterion: reproduce centroid distances via an affine map
    optimization_type = OT_ReproduceDistances_affine;
    // NOTE(review): presumably 0 means "use all training points" —
    // confirm against the header documentation
    ntrain_permutation = 0;
    // decay factor for the distance weighting
    dis_weight_factor = log(2);
}
764
+
765
+
766
+
767
/// For each sub-quantizer, find a permutation of the centroids (by
/// simulated annealing) such that Hamming distances between codes
/// reproduce the L2 distances between centroids, then apply the
/// permutation to the centroid table in place.
/// @param pq  product quantizer whose centroids are permuted (modified)
void PolysemousTraining::optimize_reproduce_distances (
        ProductQuantizer &pq) const
{

    int dsub = pq.dsub;

    int n = pq.ksub;
    int nbits = pq.nbits;

    // sub-quantizers are independent: optimize them in parallel
#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        std::vector<double> dis_table;

        // printf ("Optimizing quantizer %d\n", m);

        float * centroids = pq.get_centroids (m, 0);

        // full n*n table of squared L2 distances between centroids
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                dis_table.push_back (fvec_L2sqr (centroids + i * dsub,
                                                 centroids + j * dsub,
                                                 dsub));
            }
        }

        std::vector<int> perm (n);
        ReproduceWithHammingObjective obj (
            nbits, dis_table,
            dis_weight_factor);


        SimulatedAnnealingOptimizer optim (&obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            // log_pattern is a printf pattern with one %d slot for m
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
        }
        double final_cost = optim.run_optimization (perm.data());

        if (verbose > 0) {
            printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                    m, optim.init_cost, final_cost);
        }

        if (log_pattern.size()) fclose (optim.logfile);

        // apply the permutation in place: centroid i moves to slot perm[i]
        std::vector<float> centroids_copy;
        for (int i = 0; i < dsub * n; i++)
            centroids_copy.push_back (centroids[i]);

        for (int i = 0; i < n; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));

    }

}
828
+
829
+
830
/// For each sub-quantizer, find a permutation of the centroids (by
/// simulated annealing on a RankingScore2 objective) and apply it to
/// the centroid table in place. With n > 0 the ground truth comes
/// from pairwise distances on the training set; with n == 0 it comes
/// from the quantizer's SDC table.
/// @param pq  product quantizer to optimize (modified)
/// @param n   number of training vectors (0 = use the SDC table)
/// @param x   training vectors, size n * pq.d
void PolysemousTraining::optimize_ranking (
        ProductQuantizer &pq, size_t n, const float *x) const
{

    int dsub = pq.dsub;

    int nbits = pq.nbits;

    std::vector<uint8_t> all_codes (pq.code_size * n);

    pq.compute_codes (x, all_codes.data(), n);

    // codes are stored below as one uint8 per sub-quantizer
    FAISS_THROW_IF_NOT (pq.nbits == 8);

    if (n == 0)
        pq.compute_sdc_table ();

    // sub-quantizers are independent: optimize them in parallel
#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        size_t nq, nb;
        std::vector <uint32_t> codes; // query codes, then db codes
        std::vector <float> gt_distances; // nq * nb matrix of distances

        if (n > 0) {
            // extract the m-th sub-vector of each training point
            std::vector<float> xtrain (n * dsub);
            for (int i = 0; i < n; i++)
                memcpy (xtrain.data() + i * dsub,
                        x + i * pq.d + m * dsub,
                        sizeof(float) * dsub);

            codes.resize (n);
            for (int i = 0; i < n; i++)
                codes [i] = all_codes [i * pq.code_size + m];

            // split training points: first quarter = queries, rest = db
            nq = n / 4; nb = n - nq;
            const float *xq = xtrain.data();
            const float *xb = xq + nq * dsub;

            gt_distances.resize (nq * nb);

            pairwise_L2sqr (dsub,
                            nq, xq,
                            nb, xb,
                            gt_distances.data());
        } else {
            // no training data: use one "point" per centroid and the
            // precomputed symmetric distance table as ground truth
            nq = nb = pq.ksub;
            codes.resize (2 * nq);
            for (int i = 0; i < nq; i++)
                codes[i] = codes [i + nq] = i;

            gt_distances.resize (nq * nb);

            memcpy (gt_distances.data (),
                    pq.sdc_table.data () + m * nq * nb,
                    sizeof (float) * nq * nb);
        }

        double t0 = getmillisecs ();

        PermutationObjective *obj = new RankingScore2 (
            nbits, nq, nb,
            codes.data(), codes.data() + nq,
            gt_distances.data ());
        // RAII-style deleter: frees obj at end of scope
        ScopeDeleter1<PermutationObjective> del (obj);

        if (verbose > 0) {
            printf(" m=%d, nq=%ld, nb=%ld, intialize RankingScore "
                   "in %.3f ms\n",
                   m, nq, nb, getmillisecs () - t0);
        }

        SimulatedAnnealingOptimizer optim (obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            // log_pattern is a printf pattern with one %d slot for m
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_FMT (optim.logfile,
                                    "could not open logfile %s", fname);
        }

        std::vector<int> perm (pq.ksub);

        double final_cost = optim.run_optimization (perm.data());
        printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                m, optim.init_cost, final_cost);

        if (log_pattern.size()) fclose (optim.logfile);

        float * centroids = pq.get_centroids (m, 0);

        // apply the permutation in place: centroid i moves to slot perm[i]
        std::vector<float> centroids_copy;
        for (int i = 0; i < dsub * pq.ksub; i++)
            centroids_copy.push_back (centroids[i]);

        for (int i = 0; i < pq.ksub; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));

    }

}
934
+
935
+
936
+
937
+ void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq,
938
+ size_t n, const float *x) const
939
+ {
940
+ if (optimization_type == OT_None) {
941
+
942
+ } else if (optimization_type == OT_ReproduceDistances_affine) {
943
+ optimize_reproduce_distances (pq);
944
+ } else {
945
+ optimize_ranking (pq, n, x);
946
+ }
947
+
948
+ pq.compute_sdc_table ();
949
+
950
+ }
951
+
952
+
953
+ } // namespace faiss