faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,170 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <vector>
13
+
14
+ #include <faiss/impl/HNSW.h>
15
+ #include <faiss/IndexFlat.h>
16
+ #include <faiss/IndexPQ.h>
17
+ #include <faiss/IndexScalarQuantizer.h>
18
+ #include <faiss/utils/utils.h>
19
+
20
+
21
+ namespace faiss {
22
+
23
+ struct IndexHNSW;
24
+
25
+ struct ReconstructFromNeighbors {
26
+ typedef Index::idx_t idx_t;
27
+ typedef HNSW::storage_idx_t storage_idx_t;
28
+
29
+ const IndexHNSW & index;
30
+ size_t M; // number of neighbors
31
+ size_t k; // number of codebook entries
32
+ size_t nsq; // number of subvectors
33
+ size_t code_size;
34
+ int k_reorder; // nb to reorder. -1 = all
35
+
36
+ std::vector<float> codebook; // size nsq * k * (M + 1)
37
+
38
+ std::vector<uint8_t> codes; // size ntotal * code_size
39
+ size_t ntotal;
40
+ size_t d, dsub; // derived values
41
+
42
+ explicit ReconstructFromNeighbors(const IndexHNSW& index,
43
+ size_t k=256, size_t nsq=1);
44
+
45
+ /// codes must be added in the correct order and the IndexHNSW
46
+ /// must be populated and sorted
47
+ void add_codes(size_t n, const float *x);
48
+
49
+ size_t compute_distances(size_t n, const idx_t *shortlist,
50
+ const float *query, float *distances) const;
51
+
52
+ /// called by add_codes
53
+ void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
54
+
55
+ /// called by compute_distances
56
+ void reconstruct(storage_idx_t i, float *x, float *tmp) const;
57
+
58
+ void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
59
+
60
+ /// get the M+1 -by-d table for neighbor coordinates for vector i
61
+ void get_neighbor_table(storage_idx_t i, float *out) const;
62
+
63
+ };
64
+
65
+
66
+ /** The HNSW index is a normal random-access index with a HNSW
67
+ * link structure built on top */
68
+
69
+ struct IndexHNSW : Index {
70
+
71
+ typedef HNSW::storage_idx_t storage_idx_t;
72
+
73
+ // the link strcuture
74
+ HNSW hnsw;
75
+
76
+ // the sequential storage
77
+ bool own_fields;
78
+ Index *storage;
79
+
80
+ ReconstructFromNeighbors *reconstruct_from_neighbors;
81
+
82
+ explicit IndexHNSW (int d = 0, int M = 32);
83
+ explicit IndexHNSW (Index *storage, int M = 32);
84
+
85
+ ~IndexHNSW() override;
86
+
87
+ void add(idx_t n, const float *x) override;
88
+
89
+ /// Trains the storage if needed
90
+ void train(idx_t n, const float* x) override;
91
+
92
+ /// entry point for search
93
+ void search (idx_t n, const float *x, idx_t k,
94
+ float *distances, idx_t *labels) const override;
95
+
96
+ void reconstruct(idx_t key, float* recons) const override;
97
+
98
+ void reset () override;
99
+
100
+ void shrink_level_0_neighbors(int size);
101
+
102
+ /** Perform search only on level 0, given the starting points for
103
+ * each vertex.
104
+ *
105
+ * @param search_type 1:perform one search per nprobe, 2: enqueue
106
+ * all entry points
107
+ */
108
+ void search_level_0(idx_t n, const float *x, idx_t k,
109
+ const storage_idx_t *nearest, const float *nearest_d,
110
+ float *distances, idx_t *labels, int nprobe = 1,
111
+ int search_type = 1) const;
112
+
113
+ /// alternative graph building
114
+ void init_level_0_from_knngraph(
115
+ int k, const float *D, const idx_t *I);
116
+
117
+ /// alternative graph building
118
+ void init_level_0_from_entry_points(
119
+ int npt, const storage_idx_t *points,
120
+ const storage_idx_t *nearests);
121
+
122
+ // reorder links from nearest to farthest
123
+ void reorder_links();
124
+
125
+ void link_singletons();
126
+ };
127
+
128
+
129
+ /** Flat index topped with a HNSW structure to access elements
130
+ * more efficiently.
131
+ */
132
+
133
+ struct IndexHNSWFlat : IndexHNSW {
134
+ IndexHNSWFlat();
135
+ IndexHNSWFlat(int d, int M);
136
+ };
137
+
138
+ /** PQ index topped with a HNSW structure to access elements
139
+ * more efficiently.
140
+ */
141
+ struct IndexHNSWPQ : IndexHNSW {
142
+ IndexHNSWPQ();
143
+ IndexHNSWPQ(int d, int pq_m, int M);
144
+ void train(idx_t n, const float* x) override;
145
+ };
146
+
147
+ /** SQ index topped with a HNSW structure to access elements
148
+ * more efficiently.
149
+ */
150
+ struct IndexHNSWSQ : IndexHNSW {
151
+ IndexHNSWSQ();
152
+ IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
153
+ };
154
+
155
+ /** 2-level code structure with fast random access
156
+ */
157
+ struct IndexHNSW2Level : IndexHNSW {
158
+ IndexHNSW2Level();
159
+ IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
160
+
161
+ void flip_to_ivf();
162
+
163
+ /// entry point for search
164
+ void search (idx_t n, const float *x, idx_t k,
165
+ float *distances, idx_t *labels) const override;
166
+
167
+ };
168
+
169
+
170
+ } // namespace faiss
@@ -0,0 +1,909 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexIVF.h>
11
+
12
+
13
+ #include <omp.h>
14
+
15
+ #include <cstdio>
16
+ #include <memory>
17
+
18
+ #include <faiss/utils/utils.h>
19
+ #include <faiss/utils/hamming.h>
20
+
21
+ #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/IndexFlat.h>
23
+ #include <faiss/impl/AuxIndexStructures.h>
24
+
25
+ namespace faiss {
26
+
27
+ using ScopedIds = InvertedLists::ScopedIds;
28
+ using ScopedCodes = InvertedLists::ScopedCodes;
29
+
30
+ /*****************************************
31
+ * Level1Quantizer implementation
32
+ ******************************************/
33
+
34
+
35
+ Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist):
36
+ quantizer (quantizer),
37
+ nlist (nlist),
38
+ quantizer_trains_alone (0),
39
+ own_fields (false),
40
+ clustering_index (nullptr)
41
+ {
42
+ // here we set a low # iterations because this is typically used
43
+ // for large clusterings (nb this is not used for the MultiIndex,
44
+ // for which quantizer_trains_alone = true)
45
+ cp.niter = 10;
46
+ }
47
+
48
+ Level1Quantizer::Level1Quantizer ():
49
+ quantizer (nullptr),
50
+ nlist (0),
51
+ quantizer_trains_alone (0), own_fields (false),
52
+ clustering_index (nullptr)
53
+ {}
54
+
55
+ Level1Quantizer::~Level1Quantizer ()
56
+ {
57
+ if (own_fields) delete quantizer;
58
+ }
59
+
60
+ void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
61
+ {
62
+ size_t d = quantizer->d;
63
+ if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
64
+ if (verbose)
65
+ printf ("IVF quantizer does not need training.\n");
66
+ } else if (quantizer_trains_alone == 1) {
67
+ if (verbose)
68
+ printf ("IVF quantizer trains alone...\n");
69
+ quantizer->train (n, x);
70
+ quantizer->verbose = verbose;
71
+ FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist,
72
+ "nlist not consistent with quantizer size");
73
+ } else if (quantizer_trains_alone == 0) {
74
+ if (verbose)
75
+ printf ("Training level-1 quantizer on %ld vectors in %ldD\n",
76
+ n, d);
77
+
78
+ Clustering clus (d, nlist, cp);
79
+ quantizer->reset();
80
+ if (clustering_index) {
81
+ clus.train (n, x, *clustering_index);
82
+ quantizer->add (nlist, clus.centroids.data());
83
+ } else {
84
+ clus.train (n, x, *quantizer);
85
+ }
86
+ quantizer->is_trained = true;
87
+ } else if (quantizer_trains_alone == 2) {
88
+ if (verbose)
89
+ printf (
90
+ "Training L2 quantizer on %ld vectors in %ldD%s\n",
91
+ n, d,
92
+ clustering_index ? "(user provided index)" : "");
93
+ FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
94
+ Clustering clus (d, nlist, cp);
95
+ if (!clustering_index) {
96
+ IndexFlatL2 assigner (d);
97
+ clus.train(n, x, assigner);
98
+ } else {
99
+ clus.train(n, x, *clustering_index);
100
+ }
101
+ if (verbose)
102
+ printf ("Adding centroids to quantizer\n");
103
+ quantizer->add (nlist, clus.centroids.data());
104
+ }
105
+ }
106
+
107
+ size_t Level1Quantizer::coarse_code_size () const
108
+ {
109
+ size_t nl = nlist - 1;
110
+ size_t nbyte = 0;
111
+ while (nl > 0) {
112
+ nbyte ++;
113
+ nl >>= 8;
114
+ }
115
+ return nbyte;
116
+ }
117
+
118
+ void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
119
+ {
120
+ // little endian
121
+ size_t nl = nlist - 1;
122
+ while (nl > 0) {
123
+ *code++ = list_no & 0xff;
124
+ list_no >>= 8;
125
+ nl >>= 8;
126
+ }
127
+ }
128
+
129
+ Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
130
+ {
131
+ size_t nl = nlist - 1;
132
+ int64_t list_no = 0;
133
+ int nbit = 0;
134
+ while (nl > 0) {
135
+ list_no |= int64_t(*code++) << nbit;
136
+ nbit += 8;
137
+ nl >>= 8;
138
+ }
139
+ FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist);
140
+ return list_no;
141
+ }
142
+
143
+
144
+
145
+ /*****************************************
146
+ * IndexIVF implementation
147
+ ******************************************/
148
+
149
+
150
+ IndexIVF::IndexIVF (Index * quantizer, size_t d,
151
+ size_t nlist, size_t code_size,
152
+ MetricType metric):
153
+ Index (d, metric),
154
+ Level1Quantizer (quantizer, nlist),
155
+ invlists (new ArrayInvertedLists (nlist, code_size)),
156
+ own_invlists (true),
157
+ code_size (code_size),
158
+ nprobe (1),
159
+ max_codes (0),
160
+ parallel_mode (0),
161
+ maintain_direct_map (false)
162
+ {
163
+ FAISS_THROW_IF_NOT (d == quantizer->d);
164
+ is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
165
+ // Spherical by default if the metric is inner_product
166
+ if (metric_type == METRIC_INNER_PRODUCT) {
167
+ cp.spherical = true;
168
+ }
169
+
170
+ }
171
+
172
+ IndexIVF::IndexIVF ():
173
+ invlists (nullptr), own_invlists (false),
174
+ code_size (0),
175
+ nprobe (1), max_codes (0), parallel_mode (0),
176
+ maintain_direct_map (false)
177
+ {}
178
+
179
+ void IndexIVF::add (idx_t n, const float * x)
180
+ {
181
+ add_with_ids (n, x, nullptr);
182
+ }
183
+
184
+
185
+ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
186
+ {
187
+ // do some blocking to avoid excessive allocs
188
+ idx_t bs = 65536;
189
+ if (n > bs) {
190
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
191
+ idx_t i1 = std::min (n, i0 + bs);
192
+ if (verbose) {
193
+ printf(" IndexIVF::add_with_ids %ld:%ld\n", i0, i1);
194
+ }
195
+ add_with_ids (i1 - i0, x + i0 * d,
196
+ xids ? xids + i0 : nullptr);
197
+ }
198
+ return;
199
+ }
200
+
201
+ FAISS_THROW_IF_NOT (is_trained);
202
+ std::unique_ptr<idx_t []> idx(new idx_t[n]);
203
+ quantizer->assign (n, x, idx.get());
204
+ size_t nadd = 0, nminus1 = 0;
205
+
206
+ for (size_t i = 0; i < n; i++) {
207
+ if (idx[i] < 0) nminus1++;
208
+ }
209
+
210
+ std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
211
+ encode_vectors (n, x, idx.get(), flat_codes.get());
212
+
213
+ #pragma omp parallel reduction(+: nadd)
214
+ {
215
+ int nt = omp_get_num_threads();
216
+ int rank = omp_get_thread_num();
217
+
218
+ // each thread takes care of a subset of lists
219
+ for (size_t i = 0; i < n; i++) {
220
+ idx_t list_no = idx [i];
221
+ if (list_no >= 0 && list_no % nt == rank) {
222
+ idx_t id = xids ? xids[i] : ntotal + i;
223
+ invlists->add_entry (list_no, id,
224
+ flat_codes.get() + i * code_size);
225
+ nadd++;
226
+ }
227
+ }
228
+ }
229
+
230
+ if (verbose) {
231
+ printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1);
232
+ }
233
+
234
+ ntotal += n;
235
+ }
236
+
237
+
238
+ void IndexIVF::make_direct_map (bool new_maintain_direct_map)
239
+ {
240
+ // nothing to do
241
+ if (new_maintain_direct_map == maintain_direct_map)
242
+ return;
243
+
244
+ if (new_maintain_direct_map) {
245
+ direct_map.resize (ntotal, -1);
246
+ for (size_t key = 0; key < nlist; key++) {
247
+ size_t list_size = invlists->list_size (key);
248
+ ScopedIds idlist (invlists, key);
249
+
250
+ for (long ofs = 0; ofs < list_size; ofs++) {
251
+ FAISS_THROW_IF_NOT_MSG (
252
+ 0 <= idlist [ofs] && idlist[ofs] < ntotal,
253
+ "direct map supported only for seuquential ids");
254
+ direct_map [idlist [ofs]] = key << 32 | ofs;
255
+ }
256
+ }
257
+ } else {
258
+ direct_map.clear ();
259
+ }
260
+ maintain_direct_map = new_maintain_direct_map;
261
+ }
262
+
263
+
264
+ void IndexIVF::search (idx_t n, const float *x, idx_t k,
265
+ float *distances, idx_t *labels) const
266
+ {
267
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
268
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
269
+
270
+ double t0 = getmillisecs();
271
+ quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
272
+ indexIVF_stats.quantization_time += getmillisecs() - t0;
273
+
274
+ t0 = getmillisecs();
275
+ invlists->prefetch_lists (idx.get(), n * nprobe);
276
+
277
+ search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
278
+ distances, labels, false);
279
+ indexIVF_stats.search_time += getmillisecs() - t0;
280
+ }
281
+
282
+
283
+
284
+ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
285
+ const idx_t *keys,
286
+ const float *coarse_dis ,
287
+ float *distances, idx_t *labels,
288
+ bool store_pairs,
289
+ const IVFSearchParameters *params) const
290
+ {
291
+ long nprobe = params ? params->nprobe : this->nprobe;
292
+ long max_codes = params ? params->max_codes : this->max_codes;
293
+
294
+ size_t nlistv = 0, ndis = 0, nheap = 0;
295
+
296
+ using HeapForIP = CMin<float, idx_t>;
297
+ using HeapForL2 = CMax<float, idx_t>;
298
+
299
+ bool interrupt = false;
300
+
301
+ // don't start parallel section if single query
302
+ bool do_parallel =
303
+ parallel_mode == 0 ? n > 1 :
304
+ parallel_mode == 1 ? nprobe > 1 :
305
+ nprobe * n > 1;
306
+
307
+ #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
308
+ {
309
+ InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
310
+ ScopeDeleter1<InvertedListScanner> del(scanner);
311
+
312
+ /*****************************************************
313
+ * Depending on parallel_mode, there are two possible ways
314
+ * to organize the search. Here we define local functions
315
+ * that are in common between the two
316
+ ******************************************************/
317
+
318
+ // intialize + reorder a result heap
319
+
320
+ auto init_result = [&](float *simi, idx_t *idxi) {
321
+ if (metric_type == METRIC_INNER_PRODUCT) {
322
+ heap_heapify<HeapForIP> (k, simi, idxi);
323
+ } else {
324
+ heap_heapify<HeapForL2> (k, simi, idxi);
325
+ }
326
+ };
327
+
328
+ auto reorder_result = [&] (float *simi, idx_t *idxi) {
329
+ if (metric_type == METRIC_INNER_PRODUCT) {
330
+ heap_reorder<HeapForIP> (k, simi, idxi);
331
+ } else {
332
+ heap_reorder<HeapForL2> (k, simi, idxi);
333
+ }
334
+ };
335
+
336
+ // single list scan using the current scanner (with query
337
+ // set porperly) and storing results in simi and idxi
338
+ auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
339
+ float *simi, idx_t *idxi) {
340
+
341
+ if (key < 0) {
342
+ // not enough centroids for multiprobe
343
+ return (size_t)0;
344
+ }
345
+ FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
346
+ "Invalid key=%ld nlist=%ld\n",
347
+ key, nlist);
348
+
349
+ size_t list_size = invlists->list_size(key);
350
+
351
+ // don't waste time on empty lists
352
+ if (list_size == 0) {
353
+ return (size_t)0;
354
+ }
355
+
356
+ scanner->set_list (key, coarse_dis_i);
357
+
358
+ nlistv++;
359
+
360
+ InvertedLists::ScopedCodes scodes (invlists, key);
361
+
362
+ std::unique_ptr<InvertedLists::ScopedIds> sids;
363
+ const Index::idx_t * ids = nullptr;
364
+
365
+ if (!store_pairs) {
366
+ sids.reset (new InvertedLists::ScopedIds (invlists, key));
367
+ ids = sids->get();
368
+ }
369
+
370
+ nheap += scanner->scan_codes (list_size, scodes.get(),
371
+ ids, simi, idxi, k);
372
+
373
+ return list_size;
374
+ };
375
+
376
+ /****************************************************
377
+ * Actual loops, depending on parallel_mode
378
+ ****************************************************/
379
+
380
+ if (parallel_mode == 0) {
381
+
382
+ #pragma omp for
383
+ for (size_t i = 0; i < n; i++) {
384
+
385
+ if (interrupt) {
386
+ continue;
387
+ }
388
+
389
+ // loop over queries
390
+ scanner->set_query (x + i * d);
391
+ float * simi = distances + i * k;
392
+ idx_t * idxi = labels + i * k;
393
+
394
+ init_result (simi, idxi);
395
+
396
+ long nscan = 0;
397
+
398
+ // loop over probes
399
+ for (size_t ik = 0; ik < nprobe; ik++) {
400
+
401
+ nscan += scan_one_list (
402
+ keys [i * nprobe + ik],
403
+ coarse_dis[i * nprobe + ik],
404
+ simi, idxi
405
+ );
406
+
407
+ if (max_codes && nscan >= max_codes) {
408
+ break;
409
+ }
410
+ }
411
+
412
+ ndis += nscan;
413
+ reorder_result (simi, idxi);
414
+
415
+ if (InterruptCallback::is_interrupted ()) {
416
+ interrupt = true;
417
+ }
418
+
419
+ } // parallel for
420
+ } else if (parallel_mode == 1) {
421
+ std::vector <idx_t> local_idx (k);
422
+ std::vector <float> local_dis (k);
423
+
424
+ for (size_t i = 0; i < n; i++) {
425
+ scanner->set_query (x + i * d);
426
+ init_result (local_dis.data(), local_idx.data());
427
+
428
+ #pragma omp for schedule(dynamic)
429
+ for (size_t ik = 0; ik < nprobe; ik++) {
430
+ ndis += scan_one_list
431
+ (keys [i * nprobe + ik],
432
+ coarse_dis[i * nprobe + ik],
433
+ local_dis.data(), local_idx.data());
434
+
435
+ // can't do the test on max_codes
436
+ }
437
+ // merge thread-local results
438
+
439
+ float * simi = distances + i * k;
440
+ idx_t * idxi = labels + i * k;
441
+ #pragma omp single
442
+ init_result (simi, idxi);
443
+
444
+ #pragma omp barrier
445
+ #pragma omp critical
446
+ {
447
+ if (metric_type == METRIC_INNER_PRODUCT) {
448
+ heap_addn<HeapForIP>
449
+ (k, simi, idxi,
450
+ local_dis.data(), local_idx.data(), k);
451
+ } else {
452
+ heap_addn<HeapForL2>
453
+ (k, simi, idxi,
454
+ local_dis.data(), local_idx.data(), k);
455
+ }
456
+ }
457
+ #pragma omp barrier
458
+ #pragma omp single
459
+ reorder_result (simi, idxi);
460
+ }
461
+ } else {
462
+ FAISS_THROW_FMT ("parallel_mode %d not supported\n",
463
+ parallel_mode);
464
+ }
465
+ } // parallel section
466
+
467
+ if (interrupt) {
468
+ FAISS_THROW_MSG ("computation interrupted");
469
+ }
470
+
471
+ indexIVF_stats.nq += n;
472
+ indexIVF_stats.nlist += nlistv;
473
+ indexIVF_stats.ndis += ndis;
474
+ indexIVF_stats.nheap_updates += nheap;
475
+
476
+ }
477
+
478
+
479
+
480
+
481
+ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
482
+ RangeSearchResult *result) const
483
+ {
484
+ std::unique_ptr<idx_t[]> keys (new idx_t[nx * nprobe]);
485
+ std::unique_ptr<float []> coarse_dis (new float[nx * nprobe]);
486
+
487
+ double t0 = getmillisecs();
488
+ quantizer->search (nx, x, nprobe, coarse_dis.get (), keys.get ());
489
+ indexIVF_stats.quantization_time += getmillisecs() - t0;
490
+
491
+ t0 = getmillisecs();
492
+ invlists->prefetch_lists (keys.get(), nx * nprobe);
493
+
494
+ range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
495
+ result);
496
+
497
+ indexIVF_stats.search_time += getmillisecs() - t0;
498
+ }
499
+
500
+ void IndexIVF::range_search_preassigned (
501
+ idx_t nx, const float *x, float radius,
502
+ const idx_t *keys, const float *coarse_dis,
503
+ RangeSearchResult *result) const
504
+ {
505
+
506
+ size_t nlistv = 0, ndis = 0;
507
+ bool store_pairs = false;
508
+
509
+ std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
510
+
511
+ #pragma omp parallel reduction(+: nlistv, ndis)
512
+ {
513
+ RangeSearchPartialResult pres(result);
514
+ std::unique_ptr<InvertedListScanner> scanner
515
+ (get_InvertedListScanner(store_pairs));
516
+ FAISS_THROW_IF_NOT (scanner.get ());
517
+ all_pres[omp_get_thread_num()] = &pres;
518
+
519
+ // prepare the list scanning function
520
+
521
+ auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) {
522
+
523
+ idx_t key = keys[i * nprobe + ik]; /* select the list */
524
+ if (key < 0) return;
525
+ FAISS_THROW_IF_NOT_FMT (
526
+ key < (idx_t) nlist,
527
+ "Invalid key=%ld at ik=%ld nlist=%ld\n",
528
+ key, ik, nlist);
529
+ const size_t list_size = invlists->list_size(key);
530
+
531
+ if (list_size == 0) return;
532
+
533
+ InvertedLists::ScopedCodes scodes (invlists, key);
534
+ InvertedLists::ScopedIds ids (invlists, key);
535
+
536
+ scanner->set_list (key, coarse_dis[i * nprobe + ik]);
537
+ nlistv++;
538
+ ndis += list_size;
539
+ scanner->scan_codes_range (list_size, scodes.get(),
540
+ ids.get(), radius, qres);
541
+ };
542
+
543
+ if (parallel_mode == 0) {
544
+
545
+ #pragma omp for
546
+ for (size_t i = 0; i < nx; i++) {
547
+ scanner->set_query (x + i * d);
548
+
549
+ RangeQueryResult & qres = pres.new_result (i);
550
+
551
+ for (size_t ik = 0; ik < nprobe; ik++) {
552
+ scan_list_func (i, ik, qres);
553
+ }
554
+
555
+ }
556
+
557
+ } else if (parallel_mode == 1) {
558
+
559
+ for (size_t i = 0; i < nx; i++) {
560
+ scanner->set_query (x + i * d);
561
+
562
+ RangeQueryResult & qres = pres.new_result (i);
563
+
564
+ #pragma omp for schedule(dynamic)
565
+ for (size_t ik = 0; ik < nprobe; ik++) {
566
+ scan_list_func (i, ik, qres);
567
+ }
568
+ }
569
+ } else if (parallel_mode == 2) {
570
+ std::vector<RangeQueryResult *> all_qres (nx);
571
+ RangeQueryResult *qres = nullptr;
572
+
573
+ #pragma omp for schedule(dynamic)
574
+ for (size_t iik = 0; iik < nx * nprobe; iik++) {
575
+ size_t i = iik / nprobe;
576
+ size_t ik = iik % nprobe;
577
+ if (qres == nullptr || qres->qno != i) {
578
+ FAISS_ASSERT (!qres || i > qres->qno);
579
+ qres = &pres.new_result (i);
580
+ scanner->set_query (x + i * d);
581
+ }
582
+ scan_list_func (i, ik, *qres);
583
+ }
584
+ } else {
585
+ FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode);
586
+ }
587
+ if (parallel_mode == 0) {
588
+ pres.finalize ();
589
+ } else {
590
+ #pragma omp barrier
591
+ #pragma omp single
592
+ RangeSearchPartialResult::merge (all_pres, false);
593
+ #pragma omp barrier
594
+
595
+ }
596
+ }
597
+ indexIVF_stats.nq += nx;
598
+ indexIVF_stats.nlist += nlistv;
599
+ indexIVF_stats.ndis += ndis;
600
+ }
601
+
602
+
603
+ InvertedListScanner *IndexIVF::get_InvertedListScanner (
604
+ bool /*store_pairs*/) const
605
+ {
606
+ return nullptr;
607
+ }
608
+
609
+ void IndexIVF::reconstruct (idx_t key, float* recons) const
610
+ {
611
+ FAISS_THROW_IF_NOT_MSG (direct_map.size() == ntotal,
612
+ "direct map is not initialized");
613
+ FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
614
+ "invalid key");
615
+ idx_t list_no = direct_map[key] >> 32;
616
+ idx_t offset = direct_map[key] & 0xffffffff;
617
+ reconstruct_from_offset (list_no, offset, recons);
618
+ }
619
+
620
+
621
+ void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
622
+ {
623
+ FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
624
+
625
+ for (idx_t list_no = 0; list_no < nlist; list_no++) {
626
+ size_t list_size = invlists->list_size (list_no);
627
+ ScopedIds idlist (invlists, list_no);
628
+
629
+ for (idx_t offset = 0; offset < list_size; offset++) {
630
+ idx_t id = idlist[offset];
631
+ if (!(id >= i0 && id < i0 + ni)) {
632
+ continue;
633
+ }
634
+
635
+ float* reconstructed = recons + (id - i0) * d;
636
+ reconstruct_from_offset (list_no, offset, reconstructed);
637
+ }
638
+ }
639
+ }
640
+
641
+
642
+ /* standalone codec interface */
643
+ size_t IndexIVF::sa_code_size () const
644
+ {
645
+ size_t coarse_size = coarse_code_size();
646
+ return code_size + coarse_size;
647
+ }
648
+
649
+ void IndexIVF::sa_encode (idx_t n, const float *x,
650
+ uint8_t *bytes) const
651
+ {
652
+ FAISS_THROW_IF_NOT (is_trained);
653
+ std::unique_ptr<int64_t []> idx (new int64_t [n]);
654
+ quantizer->assign (n, x, idx.get());
655
+ encode_vectors (n, x, idx.get(), bytes, true);
656
+ }
657
+
658
+
659
+ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
660
+ float *distances, idx_t *labels,
661
+ float *recons) const
662
+ {
663
+ idx_t * idx = new idx_t [n * nprobe];
664
+ ScopeDeleter<idx_t> del (idx);
665
+ float * coarse_dis = new float [n * nprobe];
666
+ ScopeDeleter<float> del2 (coarse_dis);
667
+
668
+ quantizer->search (n, x, nprobe, coarse_dis, idx);
669
+
670
+ invlists->prefetch_lists (idx, n * nprobe);
671
+
672
+ // search_preassigned() with `store_pairs` enabled to obtain the list_no
673
+ // and offset into `codes` for reconstruction
674
+ search_preassigned (n, x, k, idx, coarse_dis,
675
+ distances, labels, true /* store_pairs */);
676
+ for (idx_t i = 0; i < n; ++i) {
677
+ for (idx_t j = 0; j < k; ++j) {
678
+ idx_t ij = i * k + j;
679
+ idx_t key = labels[ij];
680
+ float* reconstructed = recons + ij * d;
681
+ if (key < 0) {
682
+ // Fill with NaNs
683
+ memset(reconstructed, -1, sizeof(*reconstructed) * d);
684
+ } else {
685
+ int list_no = key >> 32;
686
+ int offset = key & 0xffffffff;
687
+
688
+ // Update label to the actual id
689
+ labels[ij] = invlists->get_single_id (list_no, offset);
690
+
691
+ reconstruct_from_offset (list_no, offset, reconstructed);
692
+ }
693
+ }
694
+ }
695
+ }
696
+
697
+ void IndexIVF::reconstruct_from_offset(
698
+ int64_t /*list_no*/,
699
+ int64_t /*offset*/,
700
+ float* /*recons*/) const {
701
+ FAISS_THROW_MSG ("reconstruct_from_offset not implemented");
702
+ }
703
+
704
+ void IndexIVF::reset ()
705
+ {
706
+ direct_map.clear ();
707
+ invlists->reset ();
708
+ ntotal = 0;
709
+ }
710
+
711
+
712
+ size_t IndexIVF::remove_ids (const IDSelector & sel)
713
+ {
714
+ FAISS_THROW_IF_NOT_MSG (!maintain_direct_map,
715
+ "direct map remove not implemented");
716
+
717
+ std::vector<idx_t> toremove(nlist);
718
+
719
+ #pragma omp parallel for
720
+ for (idx_t i = 0; i < nlist; i++) {
721
+ idx_t l0 = invlists->list_size (i), l = l0, j = 0;
722
+ ScopedIds idsi (invlists, i);
723
+ while (j < l) {
724
+ if (sel.is_member (idsi[j])) {
725
+ l--;
726
+ invlists->update_entry (
727
+ i, j,
728
+ invlists->get_single_id (i, l),
729
+ ScopedCodes (invlists, i, l).get());
730
+ } else {
731
+ j++;
732
+ }
733
+ }
734
+ toremove[i] = l0 - l;
735
+ }
736
+ // this will not run well in parallel on ondisk because of possible shrinks
737
+ size_t nremove = 0;
738
+ for (idx_t i = 0; i < nlist; i++) {
739
+ if (toremove[i] > 0) {
740
+ nremove += toremove[i];
741
+ invlists->resize(
742
+ i, invlists->list_size(i) - toremove[i]);
743
+ }
744
+ }
745
+ ntotal -= nremove;
746
+ return nremove;
747
+ }
748
+
749
+
750
+
751
+
752
+ void IndexIVF::train (idx_t n, const float *x)
753
+ {
754
+ if (verbose)
755
+ printf ("Training level-1 quantizer\n");
756
+
757
+ train_q1 (n, x, verbose, metric_type);
758
+
759
+ if (verbose)
760
+ printf ("Training IVF residual\n");
761
+
762
+ train_residual (n, x);
763
+ is_trained = true;
764
+
765
+ }
766
+
767
+ void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
768
+ if (verbose)
769
+ printf("IndexIVF: no residual training\n");
770
+ // does nothing by default
771
+ }
772
+
773
+
774
+ void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
775
+ {
776
+ // minimal sanity checks
777
+ FAISS_THROW_IF_NOT (other.d == d);
778
+ FAISS_THROW_IF_NOT (other.nlist == nlist);
779
+ FAISS_THROW_IF_NOT (other.code_size == code_size);
780
+ FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
781
+ "can only merge indexes of the same type");
782
+ }
783
+
784
+
785
+ void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
786
+ {
787
+ check_compatible_for_merge (other);
788
+ FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
789
+ !other.maintain_direct_map),
790
+ "direct map copy not implemented");
791
+
792
+ invlists->merge_from (other.invlists, add_id);
793
+
794
+ ntotal += other.ntotal;
795
+ other.ntotal = 0;
796
+ }
797
+
798
+
799
+ void IndexIVF::replace_invlists (InvertedLists *il, bool own)
800
+ {
801
+ if (own_invlists) {
802
+ delete invlists;
803
+ }
804
+ // FAISS_THROW_IF_NOT (ntotal == 0);
805
+ if (il) {
806
+ FAISS_THROW_IF_NOT (il->nlist == nlist &&
807
+ il->code_size == code_size);
808
+ }
809
+ invlists = il;
810
+ own_invlists = own;
811
+ }
812
+
813
+
814
+ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
815
+ idx_t a1, idx_t a2) const
816
+ {
817
+
818
+ FAISS_THROW_IF_NOT (nlist == other.nlist);
819
+ FAISS_THROW_IF_NOT (code_size == other.code_size);
820
+ FAISS_THROW_IF_NOT (!other.maintain_direct_map);
821
+ FAISS_THROW_IF_NOT_FMT (
822
+ subset_type == 0 || subset_type == 1 || subset_type == 2,
823
+ "subset type %d not implemented", subset_type);
824
+
825
+ size_t accu_n = 0;
826
+ size_t accu_a1 = 0;
827
+ size_t accu_a2 = 0;
828
+
829
+ InvertedLists *oivf = other.invlists;
830
+
831
+ for (idx_t list_no = 0; list_no < nlist; list_no++) {
832
+ size_t n = invlists->list_size (list_no);
833
+ ScopedIds ids_in (invlists, list_no);
834
+
835
+ if (subset_type == 0) {
836
+ for (idx_t i = 0; i < n; i++) {
837
+ idx_t id = ids_in[i];
838
+ if (a1 <= id && id < a2) {
839
+ oivf->add_entry (list_no,
840
+ invlists->get_single_id (list_no, i),
841
+ ScopedCodes (invlists, list_no, i).get());
842
+ other.ntotal++;
843
+ }
844
+ }
845
+ } else if (subset_type == 1) {
846
+ for (idx_t i = 0; i < n; i++) {
847
+ idx_t id = ids_in[i];
848
+ if (id % a1 == a2) {
849
+ oivf->add_entry (list_no,
850
+ invlists->get_single_id (list_no, i),
851
+ ScopedCodes (invlists, list_no, i).get());
852
+ other.ntotal++;
853
+ }
854
+ }
855
+ } else if (subset_type == 2) {
856
+ // see what is allocated to a1 and to a2
857
+ size_t next_accu_n = accu_n + n;
858
+ size_t next_accu_a1 = next_accu_n * a1 / ntotal;
859
+ size_t i1 = next_accu_a1 - accu_a1;
860
+ size_t next_accu_a2 = next_accu_n * a2 / ntotal;
861
+ size_t i2 = next_accu_a2 - accu_a2;
862
+
863
+ for (idx_t i = i1; i < i2; i++) {
864
+ oivf->add_entry (list_no,
865
+ invlists->get_single_id (list_no, i),
866
+ ScopedCodes (invlists, list_no, i).get());
867
+ }
868
+
869
+ other.ntotal += i2 - i1;
870
+ accu_a1 = next_accu_a1;
871
+ accu_a2 = next_accu_a2;
872
+ }
873
+ accu_n += n;
874
+ }
875
+ FAISS_ASSERT(accu_n == ntotal);
876
+
877
+ }
878
+
879
+
880
+
881
+
882
+ IndexIVF::~IndexIVF()
883
+ {
884
+ if (own_invlists) {
885
+ delete invlists;
886
+ }
887
+ }
888
+
889
+
890
+ void IndexIVFStats::reset()
891
+ {
892
+ memset ((void*)this, 0, sizeof (*this));
893
+ }
894
+
895
+
896
+ IndexIVFStats indexIVF_stats;
897
+
898
+ void InvertedListScanner::scan_codes_range (size_t ,
899
+ const uint8_t *,
900
+ const idx_t *,
901
+ float ,
902
+ RangeQueryResult &) const
903
+ {
904
+ FAISS_THROW_MSG ("scan_codes_range not implemented");
905
+ }
906
+
907
+
908
+
909
+ } // namespace faiss