faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,275 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <vector>
13
+ #include <unordered_set>
14
+ #include <queue>
15
+
16
+ #include <omp.h>
17
+
18
+ #include <faiss/Index.h>
19
+ #include <faiss/impl/FaissAssert.h>
20
+ #include <faiss/utils/random.h>
21
+ #include <faiss/utils/Heap.h>
22
+
23
+
24
+ namespace faiss {
25
+
26
+
27
+ /** Implementation of the Hierarchical Navigable Small World
28
+ * datastructure.
29
+ *
30
+ * Efficient and robust approximate nearest neighbor search using
31
+ * Hierarchical Navigable Small World graphs
32
+ *
33
+ * Yu. A. Malkov, D. A. Yashunin, arXiv 2017
34
+ *
35
+ * This implementation is heavily influenced by the NMSlib
36
+ * implementation by Yury Malkov and Leonid Boytsov
37
+ * (https://github.com/searchivarius/nmslib)
38
+ *
39
+ * The HNSW object stores only the neighbor link structure, see
40
+ * IndexHNSW.h for the full index object.
41
+ */
42
+
43
+
44
+ struct VisitedTable;
45
+ struct DistanceComputer; // from AuxIndexStructures
46
+
47
+ struct HNSW {
48
+ /// internal storage of vectors (32 bits: this is expensive)
49
+ typedef int storage_idx_t;
50
+
51
+ /// Faiss results are 64-bit
52
+ typedef Index::idx_t idx_t;
53
+
54
+ typedef std::pair<float, storage_idx_t> Node;
55
+
56
+ /** Heap structure that allows fast
57
+ */
58
+ struct MinimaxHeap {
59
+ int n;
60
+ int k;
61
+ int nvalid;
62
+
63
+ std::vector<storage_idx_t> ids;
64
+ std::vector<float> dis;
65
+ typedef faiss::CMax<float, storage_idx_t> HC;
66
+
67
+ explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {}
68
+
69
+ void push(storage_idx_t i, float v);
70
+
71
+ float max() const;
72
+
73
+ int size() const;
74
+
75
+ void clear();
76
+
77
+ int pop_min(float *vmin_out = nullptr);
78
+
79
+ int count_below(float thresh);
80
+ };
81
+
82
+
83
+ /// to sort pairs of (id, distance) from nearest to fathest or the reverse
84
+ struct NodeDistCloser {
85
+ float d;
86
+ int id;
87
+ NodeDistCloser(float d, int id): d(d), id(id) {}
88
+ bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; }
89
+ };
90
+
91
+ struct NodeDistFarther {
92
+ float d;
93
+ int id;
94
+ NodeDistFarther(float d, int id): d(d), id(id) {}
95
+ bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; }
96
+ };
97
+
98
+
99
+ /// assignment probability to each layer (sum=1)
100
+ std::vector<double> assign_probas;
101
+
102
+ /// number of neighbors stored per layer (cumulative), should not
103
+ /// be changed after first add
104
+ std::vector<int> cum_nneighbor_per_level;
105
+
106
+ /// level of each vector (base level = 1), size = ntotal
107
+ std::vector<int> levels;
108
+
109
+ /// offsets[i] is the offset in the neighbors array where vector i is stored
110
+ /// size ntotal + 1
111
+ std::vector<size_t> offsets;
112
+
113
+ /// neighbors[offsets[i]:offsets[i+1]] is the list of neighbors of vector i
114
+ /// for all levels. this is where all storage goes.
115
+ std::vector<storage_idx_t> neighbors;
116
+
117
+ /// entry point in the search structure (one of the points with maximum level
118
+ storage_idx_t entry_point;
119
+
120
+ faiss::RandomGenerator rng;
121
+
122
+ /// maximum level
123
+ int max_level;
124
+
125
+ /// expansion factor at construction time
126
+ int efConstruction;
127
+
128
+ /// expansion factor at search time
129
+ int efSearch;
130
+
131
+ /// during search: do we check whether the next best distance is good enough?
132
+ bool check_relative_distance = true;
133
+
134
+ /// number of entry points in levels > 0.
135
+ int upper_beam;
136
+
137
+ /// use bounded queue during exploration
138
+ bool search_bounded_queue = true;
139
+
140
+ // methods that initialize the tree sizes
141
+
142
+ /// initialize the assign_probas and cum_nneighbor_per_level to
143
+ /// have 2*M links on level 0 and M links on levels > 0
144
+ void set_default_probas(int M, float levelMult);
145
+
146
+ /// set nb of neighbors for this level (before adding anything)
147
+ void set_nb_neighbors(int level_no, int n);
148
+
149
+ // methods that access the tree sizes
150
+
151
+ /// nb of neighbors for this level
152
+ int nb_neighbors(int layer_no) const;
153
+
154
+ /// cumumlative nb up to (and excluding) this level
155
+ int cum_nb_neighbors(int layer_no) const;
156
+
157
+ /// range of entries in the neighbors table of vertex no at layer_no
158
+ void neighbor_range(idx_t no, int layer_no,
159
+ size_t * begin, size_t * end) const;
160
+
161
+ /// only mandatory parameter: nb of neighbors
162
+ explicit HNSW(int M = 32);
163
+
164
+ /// pick a random level for a new point
165
+ int random_level();
166
+
167
+ /// add n random levels to table (for debugging...)
168
+ void fill_with_random_links(size_t n);
169
+
170
+ void add_links_starting_from(DistanceComputer& ptdis,
171
+ storage_idx_t pt_id,
172
+ storage_idx_t nearest,
173
+ float d_nearest,
174
+ int level,
175
+ omp_lock_t *locks,
176
+ VisitedTable &vt);
177
+
178
+
179
+ /** add point pt_id on all levels <= pt_level and build the link
180
+ * structure for them. */
181
+ void add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
182
+ std::vector<omp_lock_t>& locks,
183
+ VisitedTable& vt);
184
+
185
+ int search_from_candidates(DistanceComputer& qdis, int k,
186
+ idx_t *I, float *D,
187
+ MinimaxHeap& candidates,
188
+ VisitedTable &vt,
189
+ int level, int nres_in = 0) const;
190
+
191
+ std::priority_queue<Node> search_from_candidate_unbounded(
192
+ const Node& node,
193
+ DistanceComputer& qdis,
194
+ int ef,
195
+ VisitedTable *vt
196
+ ) const;
197
+
198
+ /// search interface
199
+ void search(DistanceComputer& qdis, int k,
200
+ idx_t *I, float *D,
201
+ VisitedTable& vt) const;
202
+
203
+ void reset();
204
+
205
+ void clear_neighbor_tables(int level);
206
+ void print_neighbor_stats(int level) const;
207
+
208
+ int prepare_level_tab(size_t n, bool preset_levels = false);
209
+
210
+ static void shrink_neighbor_list(
211
+ DistanceComputer& qdis,
212
+ std::priority_queue<NodeDistFarther>& input,
213
+ std::vector<NodeDistFarther>& output,
214
+ int max_size);
215
+
216
+ };
217
+
218
+
219
+ /**************************************************************
220
+ * Auxiliary structures
221
+ **************************************************************/
222
+
223
/// Set of ids in [0, size) optimized for fast access.
///
/// Each slot stores the generation number (visno) in which it was last set.
/// An id counts as visited only if its slot equals the current generation,
/// so advance() clears every flag at once by bumping the generation.
struct VisitedTable {
    std::vector<uint8_t> visited;
    int visno;

    explicit VisitedTable(int size)
        : visited(size), visno(1) {}

    /// set flag #no to true
    void set(int no) {
        visited[no] = visno;
    }

    /// get flag #no
    bool get(int no) const {
        return visited[no] == visno;
    }

    /// reset all flags to false
    void advance() {
        visno++;
        if (visno == 250) {
            // 250 rather than 255 because sometimes we use visno and visno+1.
            // The generation must fit in a uint8_t slot, so before it can
            // wrap, physically zero the table and restart at generation 1.
            // vector::assign keeps size and capacity (no reallocation) and,
            // unlike memset, needs no <cstring> include in this header.
            visited.assign(visited.size(), 0);
            visno = 1;
        }
    }
};
251
+
252
+
253
/// Statistics counters for HNSW (collected in the global hnsw_stats).
struct HNSWStats {
    size_t n1, n2, n3;
    size_t ndis;
    size_t nreorder;
    bool view;

    /// starts with every counter at zero and view == false
    HNSWStats() {
        reset();
    }

    /// zero every counter and clear the view flag
    void reset() {
        view = false;
        nreorder = 0;
        ndis = 0;
        n3 = 0;
        n2 = 0;
        n1 = 0;
    }
};

/// global instance that collects the statistics (defined out of line)
extern HNSWStats hnsw_stats;
273
+
274
+
275
+ } // namespace faiss
@@ -0,0 +1,953 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/PolysemousTraining.h>
11
+
12
+ #include <cstdlib>
13
+ #include <cmath>
14
+ #include <cstring>
15
+ #include <stdint.h>
16
+
17
+ #include <algorithm>
18
+
19
+ #include <faiss/utils/random.h>
20
+ #include <faiss/utils/utils.h>
21
+ #include <faiss/utils/distances.h>
22
+ #include <faiss/utils/hamming.h>
23
+
24
+ #include <faiss/impl/FaissAssert.h>
25
+
26
+ /*****************************************
27
+ * Mixed PQ / Hamming
28
+ ******************************************/
29
+
30
+ namespace faiss {
31
+
32
+
33
+ /****************************************************
34
+ * Optimization code
35
+ ****************************************************/
36
+
37
+ SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
38
+ {
39
+ // set some reasonable defaults for the optimization
40
+ init_temperature = 0.7;
41
+ temperature_decay = pow (0.9, 1/500.);
42
+ // reduce by a factor 0.9 every 500 it
43
+ n_iter = 500000;
44
+ n_redo = 2;
45
+ seed = 123;
46
+ verbose = 0;
47
+ only_bit_flips = false;
48
+ init_random = false;
49
+ }
50
+
51
+ // what would the cost update be if iw and jw were swapped?
52
+ // default implementation just computes both and computes the difference
53
+ double PermutationObjective::cost_update (
54
+ const int *perm, int iw, int jw) const
55
+ {
56
+ double orig_cost = compute_cost (perm);
57
+
58
+ std::vector<int> perm2 (n);
59
+ for (int i = 0; i < n; i++)
60
+ perm2[i] = perm[i];
61
+ perm2[iw] = perm[jw];
62
+ perm2[jw] = perm[iw];
63
+
64
+ double new_cost = compute_cost (perm2.data());
65
+ return new_cost - orig_cost;
66
+ }
67
+
68
+
69
+
70
+
71
+ SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
72
+ PermutationObjective *obj,
73
+ const SimulatedAnnealingParameters &p):
74
+ SimulatedAnnealingParameters (p),
75
+ obj (obj),
76
+ n(obj->n),
77
+ logfile (nullptr)
78
+ {
79
+ rnd = new RandomGenerator (p.seed);
80
+ FAISS_THROW_IF_NOT (n < 100000 && n >=0 );
81
+ }
82
+
83
+ SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
84
+ {
85
+ delete rnd;
86
+ }
87
+
88
+ // run the optimization and return the best result in best_perm
89
+ double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
90
+ {
91
+ double min_cost = 1e30;
92
+
93
+ // just do a few runs of the annealing and keep the lowest output cost
94
+ for (int it = 0; it < n_redo; it++) {
95
+ std::vector<int> perm(n);
96
+ for (int i = 0; i < n; i++)
97
+ perm[i] = i;
98
+ if (init_random) {
99
+ for (int i = 0; i < n; i++) {
100
+ int j = i + rnd->rand_int (n - i);
101
+ std::swap (perm[i], perm[j]);
102
+ }
103
+ }
104
+ float cost = optimize (perm.data());
105
+ if (logfile) fprintf (logfile, "\n");
106
+ if(verbose > 1) {
107
+ printf (" optimization run %d: cost=%g %s\n",
108
+ it, cost, cost < min_cost ? "keep" : "");
109
+ }
110
+ if (cost < min_cost) {
111
+ memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
112
+ min_cost = cost;
113
+ }
114
+ }
115
+ return min_cost;
116
+ }
117
+
118
+ // perform the optimization loop, starting from and modifying
119
+ // permutation in-place
120
+ double SimulatedAnnealingOptimizer::optimize (int *perm)
121
+ {
122
+ double cost = init_cost = obj->compute_cost (perm);
123
+ int log2n = 0;
124
+ while (!(n <= (1 << log2n))) log2n++;
125
+ double temperature = init_temperature;
126
+ int n_swap = 0, n_hot = 0;
127
+ for (int it = 0; it < n_iter; it++) {
128
+ temperature = temperature * temperature_decay;
129
+ int iw, jw;
130
+ if (only_bit_flips) {
131
+ iw = rnd->rand_int (n);
132
+ jw = iw ^ (1 << rnd->rand_int (log2n));
133
+ } else {
134
+ iw = rnd->rand_int (n);
135
+ jw = rnd->rand_int (n - 1);
136
+ if (jw == iw) jw++;
137
+ }
138
+ double delta_cost = obj->cost_update (perm, iw, jw);
139
+ if (delta_cost < 0 || rnd->rand_float () < temperature) {
140
+ std::swap (perm[iw], perm[jw]);
141
+ cost += delta_cost;
142
+ n_swap++;
143
+ if (delta_cost >= 0) n_hot++;
144
+ }
145
+ if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
146
+ printf (" iteration %d cost %g temp %g n_swap %d "
147
+ "(%d hot) \r",
148
+ it, cost, temperature, n_swap, n_hot);
149
+ fflush(stdout);
150
+ }
151
+ if (logfile) {
152
+ fprintf (logfile, "%d %g %g %d %d\n",
153
+ it, cost, temperature, n_swap, n_hot);
154
+ }
155
+ }
156
+ if (verbose > 1) printf("\n");
157
+ return cost;
158
+ }
159
+
160
+
161
+
162
+
163
+
164
+ /****************************************************
165
+ * Cost functions: ReproduceDistanceTable
166
+ ****************************************************/
167
+
168
+
169
+
170
+
171
+
172
+
173
/// Hamming distance between two 64-bit words (number of differing bits).
static inline int hamming_dis (uint64_t a, uint64_t b)
{
    // popcountll rather than popcountl. `long` is only 32 bits on LLP64
    // targets (e.g. 64-bit Windows), where __builtin_popcountl would silently
    // drop the upper half of a ^ b.
    return __builtin_popcountll (a ^ b);
}
177
+
178
+ namespace {
179
+
180
+ /// optimize permutation to reproduce a distance table with Hamming distances
181
+ struct ReproduceWithHammingObjective : PermutationObjective {
182
+ int nbits;
183
+ double dis_weight_factor;
184
+
185
+ static double sqr (double x) { return x * x; }
186
+
187
+
188
+ // weihgting of distances: it is more important to reproduce small
189
+ // distances well
190
+ double dis_weight (double x) const
191
+ {
192
+ return exp (-dis_weight_factor * x);
193
+ }
194
+
195
+ std::vector<double> target_dis; // wanted distances (size n^2)
196
+ std::vector<double> weights; // weights for each distance (size n^2)
197
+
198
+ // cost = quadratic difference between actual distance and Hamming distance
199
+ double compute_cost(const int* perm) const override {
200
+ double cost = 0;
201
+ for (int i = 0; i < n; i++) {
202
+ for (int j = 0; j < n; j++) {
203
+ double wanted = target_dis[i * n + j];
204
+ double w = weights[i * n + j];
205
+ double actual = hamming_dis(perm[i], perm[j]);
206
+ cost += w * sqr(wanted - actual);
207
+ }
208
+ }
209
+ return cost;
210
+ }
211
+
212
+
213
+ // what would the cost update be if iw and jw were swapped?
214
+ // computed in O(n) instead of O(n^2) for the full re-computation
215
+ double cost_update(const int* perm, int iw, int jw) const override {
216
+ double delta_cost = 0;
217
+
218
+ for (int i = 0; i < n; i++) {
219
+ if (i == iw) {
220
+ for (int j = 0; j < n; j++) {
221
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
222
+ double actual = hamming_dis(perm[i], perm[j]);
223
+ delta_cost -= w * sqr(wanted - actual);
224
+ double new_actual =
225
+ hamming_dis(perm[jw], perm[j == iw ? jw : j == jw ? iw : j]);
226
+ delta_cost += w * sqr(wanted - new_actual);
227
+ }
228
+ } else if (i == jw) {
229
+ for (int j = 0; j < n; j++) {
230
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
231
+ double actual = hamming_dis(perm[i], perm[j]);
232
+ delta_cost -= w * sqr(wanted - actual);
233
+ double new_actual =
234
+ hamming_dis(perm[iw], perm[j == iw ? jw : j == jw ? iw : j]);
235
+ delta_cost += w * sqr(wanted - new_actual);
236
+ }
237
+ } else {
238
+ int j = iw;
239
+ {
240
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
241
+ double actual = hamming_dis(perm[i], perm[j]);
242
+ delta_cost -= w * sqr(wanted - actual);
243
+ double new_actual = hamming_dis(perm[i], perm[jw]);
244
+ delta_cost += w * sqr(wanted - new_actual);
245
+ }
246
+ j = jw;
247
+ {
248
+ double wanted = target_dis[i * n + j], w = weights[i * n + j];
249
+ double actual = hamming_dis(perm[i], perm[j]);
250
+ delta_cost -= w * sqr(wanted - actual);
251
+ double new_actual = hamming_dis(perm[i], perm[iw]);
252
+ delta_cost += w * sqr(wanted - new_actual);
253
+ }
254
+ }
255
+ }
256
+
257
+ return delta_cost;
258
+ }
259
+
260
+
261
+
262
+ ReproduceWithHammingObjective (
263
+ int nbits,
264
+ const std::vector<double> & dis_table,
265
+ double dis_weight_factor):
266
+ nbits (nbits), dis_weight_factor (dis_weight_factor)
267
+ {
268
+ n = 1 << nbits;
269
+ FAISS_THROW_IF_NOT (dis_table.size() == n * n);
270
+ set_affine_target_dis (dis_table);
271
+ }
272
+
273
+ void set_affine_target_dis (const std::vector<double> & dis_table)
274
+ {
275
+ double sum = 0, sum2 = 0;
276
+ int n2 = n * n;
277
+ for (int i = 0; i < n2; i++) {
278
+ sum += dis_table [i];
279
+ sum2 += dis_table [i] * dis_table [i];
280
+ }
281
+ double mean = sum / n2;
282
+ double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
283
+
284
+ target_dis.resize (n2);
285
+
286
+ for (int i = 0; i < n2; i++) {
287
+ // the mapping function
288
+ double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
289
+ nbits / 2;
290
+ target_dis[i] = td;
291
+ // compute a weight
292
+ weights.push_back (dis_weight (td));
293
+ }
294
+
295
+ }
296
+
297
+ ~ReproduceWithHammingObjective() override {}
298
+ };
299
+
300
+ } // anonymous namespace
301
+
302
+ // weihgting of distances: it is more important to reproduce small
303
+ // distances well
304
+ double ReproduceDistancesObjective::dis_weight (double x) const
305
+ {
306
+ return exp (-dis_weight_factor * x);
307
+ }
308
+
309
+
310
+ double ReproduceDistancesObjective::get_source_dis (int i, int j) const
311
+ {
312
+ return source_dis [i * n + j];
313
+ }
314
+
315
+ // cost = quadratic difference between actual distance and Hamming distance
316
+ double ReproduceDistancesObjective::compute_cost (const int *perm) const
317
+ {
318
+ double cost = 0;
319
+ for (int i = 0; i < n; i++) {
320
+ for (int j = 0; j < n; j++) {
321
+ double wanted = target_dis [i * n + j];
322
+ double w = weights [i * n + j];
323
+ double actual = get_source_dis (perm[i], perm[j]);
324
+ cost += w * sqr (wanted - actual);
325
+ }
326
+ }
327
+ return cost;
328
+ }
329
+
330
+ // what would the cost update be if iw and jw were swapped?
331
+ // computed in O(n) instead of O(n^2) for the full re-computation
332
+ double ReproduceDistancesObjective::cost_update(
333
+ const int *perm, int iw, int jw) const
334
+ {
335
+ double delta_cost = 0;
336
+ for (int i = 0; i < n; i++) {
337
+ if (i == iw) {
338
+ for (int j = 0; j < n; j++) {
339
+ double wanted = target_dis [i * n + j],
340
+ w = weights [i * n + j];
341
+ double actual = get_source_dis (perm[i], perm[j]);
342
+ delta_cost -= w * sqr (wanted - actual);
343
+ double new_actual = get_source_dis (
344
+ perm[jw],
345
+ perm[j == iw ? jw : j == jw ? iw : j]);
346
+ delta_cost += w * sqr (wanted - new_actual);
347
+ }
348
+ } else if (i == jw) {
349
+ for (int j = 0; j < n; j++) {
350
+ double wanted = target_dis [i * n + j],
351
+ w = weights [i * n + j];
352
+ double actual = get_source_dis (perm[i], perm[j]);
353
+ delta_cost -= w * sqr (wanted - actual);
354
+ double new_actual = get_source_dis (
355
+ perm[iw],
356
+ perm[j == iw ? jw : j == jw ? iw : j]);
357
+ delta_cost += w * sqr (wanted - new_actual);
358
+ }
359
+ } else {
360
+ int j = iw;
361
+ {
362
+ double wanted = target_dis [i * n + j],
363
+ w = weights [i * n + j];
364
+ double actual = get_source_dis (perm[i], perm[j]);
365
+ delta_cost -= w * sqr (wanted - actual);
366
+ double new_actual = get_source_dis (perm[i], perm[jw]);
367
+ delta_cost += w * sqr (wanted - new_actual);
368
+ }
369
+ j = jw;
370
+ {
371
+ double wanted = target_dis [i * n + j],
372
+ w = weights [i * n + j];
373
+ double actual = get_source_dis (perm[i], perm[j]);
374
+ delta_cost -= w * sqr (wanted - actual);
375
+ double new_actual = get_source_dis (perm[i], perm[iw]);
376
+ delta_cost += w * sqr (wanted - new_actual);
377
+ }
378
+ }
379
+ }
380
+ return delta_cost;
381
+ }
382
+
383
+
384
+
385
+ ReproduceDistancesObjective::ReproduceDistancesObjective (
386
+ int n,
387
+ const double *source_dis_in,
388
+ const double *target_dis_in,
389
+ double dis_weight_factor):
390
+ dis_weight_factor (dis_weight_factor),
391
+ target_dis (target_dis_in)
392
+ {
393
+ this->n = n;
394
+ set_affine_target_dis (source_dis_in);
395
+ }
396
+
397
+ void ReproduceDistancesObjective::compute_mean_stdev (
398
+ const double *tab, size_t n2,
399
+ double *mean_out, double *stddev_out)
400
+ {
401
+ double sum = 0, sum2 = 0;
402
+ for (int i = 0; i < n2; i++) {
403
+ sum += tab [i];
404
+ sum2 += tab [i] * tab [i];
405
+ }
406
+ double mean = sum / n2;
407
+ double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
408
+ *mean_out = mean;
409
+ *stddev_out = stddev;
410
+ }
411
+
412
+ void ReproduceDistancesObjective::set_affine_target_dis (
413
+ const double *source_dis_in)
414
+ {
415
+ int n2 = n * n;
416
+
417
+ double mean_src, stddev_src;
418
+ compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);
419
+
420
+ double mean_target, stddev_target;
421
+ compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);
422
+
423
+ printf ("map mean %g std %g -> mean %g std %g\n",
424
+ mean_src, stddev_src, mean_target, stddev_target);
425
+
426
+ source_dis.resize (n2);
427
+ weights.resize (n2);
428
+
429
+ for (int i = 0; i < n2; i++) {
430
+ // the mapping function
431
+ source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
432
+ * stddev_target + mean_target;
433
+
434
+ // compute a weight
435
+ weights [i] = dis_weight (target_dis[i]);
436
+ }
437
+
438
+ }
439
+
440
+ /****************************************************
441
+ * Cost functions: RankingScore
442
+ ****************************************************/
443
+
444
+ /// Maintains a 3D table of elementary costs.
445
+ /// Accumulates elements based on Hamming distance comparisons
446
+ template <typename Ttab, typename Taccu>
447
+ struct Score3Computer: PermutationObjective {
448
+
449
+ int nc;
450
+
451
+ // cost matrix of size nc * nc *nc
452
+ // n_gt (i,j,k) = count of d_gt(x, y-) < d_gt(x, y+)
453
+ // where x has PQ code i, y- PQ code j and y+ PQ code k
454
+ std::vector<Ttab> n_gt;
455
+
456
+
457
+ /// the cost is a triple loop on the nc * nc * nc matrix of entries.
458
+ ///
459
+ Taccu compute (const int * perm) const
460
+ {
461
+ Taccu accu = 0;
462
+ const Ttab *p = n_gt.data();
463
+ for (int i = 0; i < nc; i++) {
464
+ int ip = perm [i];
465
+ for (int j = 0; j < nc; j++) {
466
+ int jp = perm [j];
467
+ for (int k = 0; k < nc; k++) {
468
+ int kp = perm [k];
469
+ if (hamming_dis (ip, jp) <
470
+ hamming_dis (ip, kp)) {
471
+ accu += *p; // n_gt [ ( i * nc + j) * nc + k];
472
+ }
473
+ p++;
474
+ }
475
+ }
476
+ }
477
+ return accu;
478
+ }
479
+
480
+
481
+ /** cost update if entries iw and jw of the permutation would be
482
+ * swapped.
483
+ *
484
+ * The computation is optimized by avoiding elements in the
485
+ * nc*nc*nc cube that are known not to change. For nc=256, this
486
+ * reduces the nb of cells to visit to about 6/256 th of the
487
+ * cells. Practical speedup is about 8x, and the code is quite
488
+ * complex :-/
489
+ */
490
+ Taccu compute_update (const int *perm, int iw, int jw) const
491
+ {
492
+ assert (iw != jw);
493
+ if (iw > jw) std::swap (iw, jw);
494
+
495
+ Taccu accu = 0;
496
+ const Ttab * n_gt_i = n_gt.data();
497
+ for (int i = 0; i < nc; i++) {
498
+ int ip0 = perm [i];
499
+ int ip = perm [i == iw ? jw : i == jw ? iw : i];
500
+
501
+ //accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
502
+
503
+ accu += update_i_cross (perm, iw, jw,
504
+ ip0, ip, n_gt_i);
505
+
506
+ if (ip != ip0)
507
+ accu += update_i_plane (perm, iw, jw,
508
+ ip0, ip, n_gt_i);
509
+
510
+ n_gt_i += nc * nc;
511
+ }
512
+
513
+ return accu;
514
+ }
515
+
516
+
517
+ Taccu update_i (const int *perm, int iw, int jw,
518
+ int ip0, int ip, const Ttab * n_gt_i) const
519
+ {
520
+ Taccu accu = 0;
521
+ const Ttab *n_gt_ij = n_gt_i;
522
+ for (int j = 0; j < nc; j++) {
523
+ int jp0 = perm[j];
524
+ int jp = perm [j == iw ? jw : j == jw ? iw : j];
525
+ for (int k = 0; k < nc; k++) {
526
+ int kp0 = perm [k];
527
+ int kp = perm [k == iw ? jw : k == jw ? iw : k];
528
+ int ng = n_gt_ij [k];
529
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
530
+ accu += ng;
531
+ }
532
+ if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
533
+ accu -= ng;
534
+ }
535
+ }
536
+ n_gt_ij += nc;
537
+ }
538
+ return accu;
539
+ }
540
+
541
+ // 2 inner loops for the case ip0 != ip
542
+ Taccu update_i_plane (const int *perm, int iw, int jw,
543
+ int ip0, int ip, const Ttab * n_gt_i) const
544
+ {
545
+ Taccu accu = 0;
546
+ const Ttab *n_gt_ij = n_gt_i;
547
+
548
+ for (int j = 0; j < nc; j++) {
549
+ if (j != iw && j != jw) {
550
+ int jp = perm[j];
551
+ for (int k = 0; k < nc; k++) {
552
+ if (k != iw && k != jw) {
553
+ int kp = perm [k];
554
+ Ttab ng = n_gt_ij [k];
555
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
556
+ accu += ng;
557
+ }
558
+ if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
559
+ accu -= ng;
560
+ }
561
+ }
562
+ }
563
+ }
564
+ n_gt_ij += nc;
565
+ }
566
+ return accu;
567
+ }
568
+
569
+ /// used for the 8 cells were the 3 indices are swapped
570
+ inline Taccu update_k (const int *perm, int iw, int jw,
571
+ int ip0, int ip, int jp0, int jp,
572
+ int k,
573
+ const Ttab * n_gt_ij) const
574
+ {
575
+ Taccu accu = 0;
576
+ int kp0 = perm [k];
577
+ int kp = perm [k == iw ? jw : k == jw ? iw : k];
578
+ Ttab ng = n_gt_ij [k];
579
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
580
+ accu += ng;
581
+ }
582
+ if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
583
+ accu -= ng;
584
+ }
585
+ return accu;
586
+ }
587
+
588
+ /// compute update on a line of k's, where i and j are swapped
589
+ Taccu update_j_line (const int *perm, int iw, int jw,
590
+ int ip0, int ip, int jp0, int jp,
591
+ const Ttab * n_gt_ij) const
592
+ {
593
+ Taccu accu = 0;
594
+ for (int k = 0; k < nc; k++) {
595
+ if (k == iw || k == jw) continue;
596
+ int kp = perm [k];
597
+ Ttab ng = n_gt_ij [k];
598
+ if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
599
+ accu += ng;
600
+ }
601
+ if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
602
+ accu -= ng;
603
+ }
604
+ }
605
+ return accu;
606
+ }
607
+
608
+
609
+ /// considers the 2 pairs of crossing lines j=iw or jw and k = iw or kw
610
+ Taccu update_i_cross (const int *perm, int iw, int jw,
611
+ int ip0, int ip, const Ttab * n_gt_i) const
612
+ {
613
+ Taccu accu = 0;
614
+ const Ttab *n_gt_ij = n_gt_i;
615
+
616
+ for (int j = 0; j < nc; j++) {
617
+ int jp0 = perm[j];
618
+ int jp = perm [j == iw ? jw : j == jw ? iw : j];
619
+
620
+ accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
621
+ accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
622
+
623
+ if (jp != jp0)
624
+ accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
625
+
626
+ n_gt_ij += nc;
627
+ }
628
+ return accu;
629
+ }
630
+
631
+
632
+ /// PermutationObjective implementeation (just negates the scores
633
+ /// for minimization)
634
+
635
+ double compute_cost(const int* perm) const override {
636
+ return -compute(perm);
637
+ }
638
+
639
+ double cost_update(const int* perm, int iw, int jw) const override {
640
+ double ret = -compute_update(perm, iw, jw);
641
+ return ret;
642
+ }
643
+
644
+ ~Score3Computer() override {}
645
+ };
646
+
647
+
648
+
649
+
650
+
651
/// Comparator ordering indices a, b by tab[a] < tab[b]. Used with
/// std::sort to obtain an argsort of a float array. tab is borrowed
/// (not owned) and must outlive the comparator. Kept as an aggregate
/// so it can be brace-initialized as IndirectSort s = {ptr}.
struct IndirectSort {
    const float *tab;
    // const: comparing never mutates the comparator, and this lets a
    // const IndirectSort (or a const copy inside an algorithm) be called
    bool operator () (int a, int b) const { return tab[a] < tab[b]; }
};
655
+
656
+
657
+
658
+ struct RankingScore2: Score3Computer<float, double> {
659
+ int nbits;
660
+ int nq, nb;
661
+ const uint32_t *qcodes, *bcodes;
662
+ const float *gt_distances;
663
+
664
+ RankingScore2 (int nbits, int nq, int nb,
665
+ const uint32_t *qcodes, const uint32_t *bcodes,
666
+ const float *gt_distances):
667
+ nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
668
+ bcodes(bcodes), gt_distances(gt_distances)
669
+ {
670
+ n = nc = 1 << nbits;
671
+ n_gt.resize (nc * nc * nc);
672
+ init_n_gt ();
673
+ }
674
+
675
+
676
+ double rank_weight (int r)
677
+ {
678
+ return 1.0 / (r + 1);
679
+ }
680
+
681
+ /// count nb of i, j in a x b st. i < j
682
+ /// a and b should be sorted on input
683
+ /// they are the ranks of j and k respectively.
684
+ /// specific version for diff-of-rank weighting, cannot optimized
685
+ /// with a cumulative table
686
+ double accum_gt_weight_diff (const std::vector<int> & a,
687
+ const std::vector<int> & b)
688
+ {
689
+ int nb = b.size(), na = a.size();
690
+
691
+ double accu = 0;
692
+ int j = 0;
693
+ for (int i = 0; i < na; i++) {
694
+ int ai = a[i];
695
+ while (j < nb && ai >= b[j]) j++;
696
+
697
+ double accu_i = 0;
698
+ for (int k = j; k < b.size(); k++)
699
+ accu_i += rank_weight (b[k] - ai);
700
+
701
+ accu += rank_weight (ai) * accu_i;
702
+
703
+ }
704
+ return accu;
705
+ }
706
+
707
+ void init_n_gt ()
708
+ {
709
+ for (int q = 0; q < nq; q++) {
710
+ const float *gtd = gt_distances + q * nb;
711
+ const uint32_t *cb = bcodes;// all same codes
712
+ float * n_gt_q = & n_gt [qcodes[q] * nc * nc];
713
+
714
+ printf("init gt for q=%d/%d \r", q, nq); fflush(stdout);
715
+
716
+ std::vector<int> rankv (nb);
717
+ int * ranks = rankv.data();
718
+
719
+ // elements in each code bin, ordered by rank within each bin
720
+ std::vector<std::vector<int> > tab (nc);
721
+
722
+ { // build rank table
723
+ IndirectSort s = {gtd};
724
+ for (int j = 0; j < nb; j++) ranks[j] = j;
725
+ std::sort (ranks, ranks + nb, s);
726
+ }
727
+
728
+ for (int rank = 0; rank < nb; rank++) {
729
+ int i = ranks [rank];
730
+ tab [cb[i]].push_back (rank);
731
+ }
732
+
733
+
734
+ // this is very expensive. Any suggestion for improvement
735
+ // welcome.
736
+ for (int i = 0; i < nc; i++) {
737
+ std::vector<int> & di = tab[i];
738
+ for (int j = 0; j < nc; j++) {
739
+ std::vector<int> & dj = tab[j];
740
+ n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);
741
+
742
+ }
743
+ }
744
+
745
+ }
746
+
747
+ }
748
+
749
+ };
750
+
751
+
752
+ /*****************************************
753
+ * PolysemousTraining
754
+ ******************************************/
755
+
756
+
757
+
758
+ PolysemousTraining::PolysemousTraining ()
759
+ {
760
+ optimization_type = OT_ReproduceDistances_affine;
761
+ ntrain_permutation = 0;
762
+ dis_weight_factor = log(2);
763
+ }
764
+
765
+
766
+
767
+ void PolysemousTraining::optimize_reproduce_distances (
768
+ ProductQuantizer &pq) const
769
+ {
770
+
771
+ int dsub = pq.dsub;
772
+
773
+ int n = pq.ksub;
774
+ int nbits = pq.nbits;
775
+
776
+ #pragma omp parallel for
777
+ for (int m = 0; m < pq.M; m++) {
778
+ std::vector<double> dis_table;
779
+
780
+ // printf ("Optimizing quantizer %d\n", m);
781
+
782
+ float * centroids = pq.get_centroids (m, 0);
783
+
784
+ for (int i = 0; i < n; i++) {
785
+ for (int j = 0; j < n; j++) {
786
+ dis_table.push_back (fvec_L2sqr (centroids + i * dsub,
787
+ centroids + j * dsub,
788
+ dsub));
789
+ }
790
+ }
791
+
792
+ std::vector<int> perm (n);
793
+ ReproduceWithHammingObjective obj (
794
+ nbits, dis_table,
795
+ dis_weight_factor);
796
+
797
+
798
+ SimulatedAnnealingOptimizer optim (&obj, *this);
799
+
800
+ if (log_pattern.size()) {
801
+ char fname[256];
802
+ snprintf (fname, 256, log_pattern.c_str(), m);
803
+ printf ("opening log file %s\n", fname);
804
+ optim.logfile = fopen (fname, "w");
805
+ FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
806
+ }
807
+ double final_cost = optim.run_optimization (perm.data());
808
+
809
+ if (verbose > 0) {
810
+ printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
811
+ m, optim.init_cost, final_cost);
812
+ }
813
+
814
+ if (log_pattern.size()) fclose (optim.logfile);
815
+
816
+ std::vector<float> centroids_copy;
817
+ for (int i = 0; i < dsub * n; i++)
818
+ centroids_copy.push_back (centroids[i]);
819
+
820
+ for (int i = 0; i < n; i++)
821
+ memcpy (centroids + perm[i] * dsub,
822
+ centroids_copy.data() + i * dsub,
823
+ dsub * sizeof(centroids[0]));
824
+
825
+ }
826
+
827
+ }
828
+
829
+
830
+ void PolysemousTraining::optimize_ranking (
831
+ ProductQuantizer &pq, size_t n, const float *x) const
832
+ {
833
+
834
+ int dsub = pq.dsub;
835
+
836
+ int nbits = pq.nbits;
837
+
838
+ std::vector<uint8_t> all_codes (pq.code_size * n);
839
+
840
+ pq.compute_codes (x, all_codes.data(), n);
841
+
842
+ FAISS_THROW_IF_NOT (pq.nbits == 8);
843
+
844
+ if (n == 0)
845
+ pq.compute_sdc_table ();
846
+
847
+ #pragma omp parallel for
848
+ for (int m = 0; m < pq.M; m++) {
849
+ size_t nq, nb;
850
+ std::vector <uint32_t> codes; // query codes, then db codes
851
+ std::vector <float> gt_distances; // nq * nb matrix of distances
852
+
853
+ if (n > 0) {
854
+ std::vector<float> xtrain (n * dsub);
855
+ for (int i = 0; i < n; i++)
856
+ memcpy (xtrain.data() + i * dsub,
857
+ x + i * pq.d + m * dsub,
858
+ sizeof(float) * dsub);
859
+
860
+ codes.resize (n);
861
+ for (int i = 0; i < n; i++)
862
+ codes [i] = all_codes [i * pq.code_size + m];
863
+
864
+ nq = n / 4; nb = n - nq;
865
+ const float *xq = xtrain.data();
866
+ const float *xb = xq + nq * dsub;
867
+
868
+ gt_distances.resize (nq * nb);
869
+
870
+ pairwise_L2sqr (dsub,
871
+ nq, xq,
872
+ nb, xb,
873
+ gt_distances.data());
874
+ } else {
875
+ nq = nb = pq.ksub;
876
+ codes.resize (2 * nq);
877
+ for (int i = 0; i < nq; i++)
878
+ codes[i] = codes [i + nq] = i;
879
+
880
+ gt_distances.resize (nq * nb);
881
+
882
+ memcpy (gt_distances.data (),
883
+ pq.sdc_table.data () + m * nq * nb,
884
+ sizeof (float) * nq * nb);
885
+ }
886
+
887
+ double t0 = getmillisecs ();
888
+
889
+ PermutationObjective *obj = new RankingScore2 (
890
+ nbits, nq, nb,
891
+ codes.data(), codes.data() + nq,
892
+ gt_distances.data ());
893
+ ScopeDeleter1<PermutationObjective> del (obj);
894
+
895
+ if (verbose > 0) {
896
+ printf(" m=%d, nq=%ld, nb=%ld, intialize RankingScore "
897
+ "in %.3f ms\n",
898
+ m, nq, nb, getmillisecs () - t0);
899
+ }
900
+
901
+ SimulatedAnnealingOptimizer optim (obj, *this);
902
+
903
+ if (log_pattern.size()) {
904
+ char fname[256];
905
+ snprintf (fname, 256, log_pattern.c_str(), m);
906
+ printf ("opening log file %s\n", fname);
907
+ optim.logfile = fopen (fname, "w");
908
+ FAISS_THROW_IF_NOT_FMT (optim.logfile,
909
+ "could not open logfile %s", fname);
910
+ }
911
+
912
+ std::vector<int> perm (pq.ksub);
913
+
914
+ double final_cost = optim.run_optimization (perm.data());
915
+ printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
916
+ m, optim.init_cost, final_cost);
917
+
918
+ if (log_pattern.size()) fclose (optim.logfile);
919
+
920
+ float * centroids = pq.get_centroids (m, 0);
921
+
922
+ std::vector<float> centroids_copy;
923
+ for (int i = 0; i < dsub * pq.ksub; i++)
924
+ centroids_copy.push_back (centroids[i]);
925
+
926
+ for (int i = 0; i < pq.ksub; i++)
927
+ memcpy (centroids + perm[i] * dsub,
928
+ centroids_copy.data() + i * dsub,
929
+ dsub * sizeof(centroids[0]));
930
+
931
+ }
932
+
933
+ }
934
+
935
+
936
+
937
+ void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq,
938
+ size_t n, const float *x) const
939
+ {
940
+ if (optimization_type == OT_None) {
941
+
942
+ } else if (optimization_type == OT_ReproduceDistances_affine) {
943
+ optimize_reproduce_distances (pq);
944
+ } else {
945
+ optimize_ranking (pq, n, x);
946
+ }
947
+
948
+ pq.compute_sdc_table ();
949
+
950
+ }
951
+
952
+
953
+ } // namespace faiss