faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,199 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_INDEX_PQ_H
11
+ #define FAISS_INDEX_PQ_H
12
+
13
+ #include <stdint.h>
14
+
15
+ #include <vector>
16
+
17
+ #include <faiss/Index.h>
18
+ #include <faiss/impl/ProductQuantizer.h>
19
+ #include <faiss/impl/PolysemousTraining.h>
20
+
21
+ namespace faiss {
22
+
23
+
24
+ /** Index based on a product quantizer. Stored vectors are
25
+ * approximated by PQ codes. */
26
+ struct IndexPQ: Index {
27
+
28
+ /// The product quantizer used to encode the vectors
29
+ ProductQuantizer pq;
30
+
31
+ /// Codes. Size ntotal * pq.code_size
32
+ std::vector<uint8_t> codes;
33
+
34
+ /** Constructor.
35
+ *
36
+ * @param d dimensionality of the input vectors
37
+ * @param M number of subquantizers
38
+ * @param nbits number of bits per subvector index
39
+ * @param metric metric type; METRIC_L2 when omitted */
40
+ IndexPQ (int d, ///< dimensionality of the input vectors
41
+ size_t M, ///< number of subquantizers
42
+ size_t nbits, ///< number of bits per subvector index
43
+ MetricType metric = METRIC_L2);
44
+ /// Default constructor (takes no arguments)
45
+ IndexPQ ();
46
+ // Overrides of virtual methods inherited through Index
47
+ void train(idx_t n, const float* x) override;
48
+
49
+ void add(idx_t n, const float* x) override;
50
+
51
+ void search(
52
+ idx_t n,
53
+ const float* x,
54
+ idx_t k,
55
+ float* distances,
56
+ idx_t* labels) const override;
57
+
58
+ void reset() override;
59
+
60
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
61
+
62
+ void reconstruct(idx_t key, float* recons) const override;
63
+
64
+ size_t remove_ids(const IDSelector& sel) override;
65
+
66
+ /* The standalone codec interface */
67
+ size_t sa_code_size () const override;
68
+
69
+ void sa_encode (idx_t n, const float *x,
70
+ uint8_t *bytes) const override;
71
+
72
+ void sa_decode (idx_t n, const uint8_t *bytes,
73
+ float *x) const override;
74
+
75
+
76
+ DistanceComputer * get_distance_computer() const override;
77
+
78
+ /******************************************************
79
+ * Polysemous codes implementation
80
+ ******************************************************/
81
+ bool do_polysemous_training; ///< false = standard PQ
82
+
83
+ /// parameters used for the polysemous training
84
+ PolysemousTraining polysemous_training;
85
+
86
+ /// how to perform the search in search_core
87
+ enum Search_type_t {
88
+ ST_PQ, ///< asymmetric product quantizer (default)
89
+ ST_HE, ///< Hamming distance on codes
90
+ ST_generalized_HE, ///< nb of same codes
91
+ ST_SDC, ///< symmetric product quantizer (SDC)
92
+ ST_polysemous, ///< HE filter (using ht) + PQ combination
93
+ ST_polysemous_generalize, ///< Filter on generalized Hamming
94
+ };
95
+ /// search mode in use; holds one of the Search_type_t values
96
+ Search_type_t search_type;
97
+
98
+ // just encode the sign of the components, instead of using the PQ encoder
99
+ // used only for the queries
100
+ bool encode_signs;
101
+
102
+ /// Hamming threshold used for polysemy
103
+ int polysemous_ht;
104
+
105
+ // actual polysemous search
106
+ void search_core_polysemous (idx_t n, const float *x, idx_t k,
107
+ float *distances, idx_t *labels) const;
108
+
109
+ /// prepare query for a polysemous search, but instead of
110
+ /// computing the result, just get the histogram of Hamming
111
+ /// distances. May be computed on a provided dataset if xb != NULL
112
+ /// @param dist_histogram (M * nbits + 1)
113
+ void hamming_distance_histogram (idx_t n, const float *x,
114
+ idx_t nb, const float *xb,
115
+ int64_t *dist_histogram);
116
+
117
+ /** compute pairwise distances between queries and database
118
+ *
119
+ * @param n nb of query vectors
120
+ * @param x query vector, size n * d
121
+ * @param dis output distances, size n * ntotal
122
+ */
123
+ void hamming_distance_table (idx_t n, const float *x,
124
+ int32_t *dis) const;
125
+
126
+ };
127
+
128
+
129
+ /// statistics are robust to internal threading, but not if
130
+ /// IndexPQ::search is called by multiple threads
131
+ struct IndexPQStats {
132
+ size_t nq; // nb of queries run
133
+ size_t ncode; // nb of codes visited
134
+
135
+ size_t n_hamming_pass; // nb of passed Hamming distance tests (for polysemy)
136
+ // the constructor leaves all initialization to reset()
137
+ IndexPQStats () {reset (); }
138
+ void reset (); // defined out of line; presumably zeroes the counters - confirm
139
+ };
140
+ /// Statistics object shared through this header; only declared here (extern)
141
+ extern IndexPQStats indexPQ_stats;
142
+
143
+
144
+
145
+ /** Quantizer where centroids are virtual: they are the Cartesian
146
+ * product of sub-centroids. */
147
+ struct MultiIndexQuantizer: Index {
148
+ ProductQuantizer pq; // presumably holds the sub-centroids - confirm in the implementation
149
+
150
+ MultiIndexQuantizer (int d, ///< dimension of the input vectors
151
+ size_t M, ///< number of subquantizers
152
+ size_t nbits); ///< number of bits per subvector index
153
+
154
+ void train(idx_t n, const float* x) override;
155
+
156
+ void search(
157
+ idx_t n, const float* x, idx_t k,
158
+ float* distances, idx_t* labels) const override;
159
+
160
+ /// add and reset will crash at runtime
161
+ void add(idx_t n, const float* x) override;
162
+ void reset() override;
163
+ /// Default constructor with an empty body
164
+ MultiIndexQuantizer () {}
165
+
166
+ void reconstruct(idx_t key, float* recons) const override;
167
+ };
168
+
169
+
170
+ /** MultiIndexQuantizer where the PQ assignment is performed by sub-indexes
171
+ */
172
+ struct MultiIndexQuantizer2: MultiIndexQuantizer {
173
+
174
+ /// M Indexes on d / M dimensions
175
+ std::vector<Index*> assign_indexes;
176
+ bool own_fields; // presumably: whether assign_indexes are owned by this object - confirm
177
+ /// Construct from an array of sub-index pointers (presumably M entries - confirm)
178
+ MultiIndexQuantizer2 (
179
+ int d, size_t M, size_t nbits,
180
+ Index **indexes);
181
+ /// Construct from exactly two sub-indexes; this overload takes no M argument
182
+ MultiIndexQuantizer2 (
183
+ int d, size_t nbits,
184
+ Index *assign_index_0,
185
+ Index *assign_index_1);
186
+
187
+ void train(idx_t n, const float* x) override;
188
+
189
+ void search(
190
+ idx_t n, const float* x, idx_t k,
191
+ float* distances, idx_t* labels) const override;
192
+
193
+ };
194
+
195
+
196
+ } // namespace faiss
197
+
198
+
199
+ #endif
@@ -0,0 +1,288 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexPreTransform.h>
11
+
12
+ #include <cstdio>
13
+ #include <cmath>
14
+ #include <cstring>
15
+ #include <memory>
16
+
17
+ #include <faiss/utils/utils.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+
20
+ namespace faiss {
21
+
22
+ /*********************************************
23
+ * IndexPreTransform
24
+ *********************************************/
25
+
26
+ /// Default constructor: no sub-index attached, nothing owned.
27
+ IndexPreTransform::IndexPreTransform ()
28
+ : index (nullptr), own_fields (false)
29
+ {}
30
+
31
+
32
+ /// Wraps `index` with an empty transform chain (own_fields stays false).
33
+ IndexPreTransform::IndexPreTransform (Index * index)
34
+ : Index (index->d, index->metric_type),
35
+ index (index), own_fields (false)
36
+ {
37
+ ntotal = index->ntotal; // mirror the wrapped index's size
38
+ is_trained = index->is_trained; // and its training state
39
+ }
40
+
41
+
42
+ /// Wraps `index` behind one transform `ltrans` (own_fields stays false).
43
+ IndexPreTransform::IndexPreTransform (
44
+ VectorTransform * ltrans, Index * index)
45
+ : Index (index->d, index->metric_type),
46
+ index (index), own_fields (false)
47
+ {
48
+ ntotal = index->ntotal;
49
+ is_trained = index->is_trained;
50
+ prepend_transform (ltrans); // also updates d and is_trained
51
+ }
52
+
53
+ void IndexPreTransform::prepend_transform (VectorTransform *ltrans)
54
+ { // ltrans becomes the first stage: its output size must equal the current d
55
+ FAISS_THROW_IF_NOT (ltrans->d_out == d);
56
+ if (!ltrans->is_trained) is_trained = false;
57
+ chain.insert (chain.begin(), ltrans);
58
+ d = ltrans->d_in; // the wrapper now takes ltrans's input dimension
59
+ }
60
+
61
+
62
+ IndexPreTransform::~IndexPreTransform ()
63
+ {
64
+ if (own_fields) { // transforms and sub-index are deleted only when owned
65
+ for (VectorTransform *ltrans : chain) // no int-vs-size_t comparison
66
+ delete ltrans;
67
+ delete index;
68
+ }
69
+ }
70
+
71
+
72
+
73
+
74
+ void IndexPreTransform::train (idx_t n, const float *x)
75
+ { // trains untrained chain stages in order, then the sub-index if it needs it
76
+ const int nchain = static_cast<int> (chain.size()); // signed copy for index math
77
+ // Last stage that needs training; the value nchain designates the sub-index.
78
+ int last_untrained = 0;
79
+ if (!index->is_trained) {
80
+ last_untrained = nchain;
81
+ } else {
82
+ for (int i = nchain - 1; i >= 0; i--) {
83
+ if (!chain[i]->is_trained) { last_untrained = i; break; }
84
+ }
85
+ }
86
+
87
+ const float *prev_x = x; // input of the stage being processed
88
+ ScopeDeleter<float> del; // tracks the newest intermediate buffer
89
+
90
+ if (verbose) {
91
+ printf("IndexPreTransform::train: training chain 0 to %d\n",
92
+ last_untrained);
93
+ }
94
+
95
+ for (int i = 0; i <= last_untrained; i++) {
96
+
97
+ if (i < nchain) {
98
+ VectorTransform *ltrans = chain [i];
99
+ if (!ltrans->is_trained) {
100
+ if (verbose) {
101
+ printf(" Training chain component %d/%zu\n", // %zu: chain.size() is a size_t
102
+ i, chain.size());
103
+ if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
104
+ opqm->verbose = true;
105
+ }
106
+ }
107
+ ltrans->train (n, prev_x);
108
+ }
109
+ } else {
110
+ if (verbose) {
111
+ printf(" Training sub-index\n");
112
+ }
113
+ index->train (n, prev_x);
114
+ }
115
+ if (i == last_untrained) break; // no later stage needs the transformed data
116
+ if (verbose) {
117
+ printf(" Applying transform %d/%zu\n",
118
+ i, chain.size());
119
+ }
120
+
121
+ float * xt = chain[i]->apply (n, prev_x);
122
+
123
+ if (prev_x != x) delete [] prev_x; // drop the previous intermediate, never the caller's x
124
+ prev_x = xt;
125
+ del.set(xt);
126
+ }
127
+
128
+ is_trained = true;
129
+ }
130
+
131
+
132
+ const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
133
+ { // returns x itself when the chain is empty, else the last stage's output buffer
134
+ const float *cur = x; // output of the stages applied so far
135
+ ScopeDeleter<float> del; // tracks the newest intermediate buffer
136
+ // range-for: no comparison between an int counter and chain.size()
137
+ for (VectorTransform *ltrans : chain) {
138
+ float * xt = ltrans->apply (n, cur);
139
+ ScopeDeleter<float> del2 (xt);
140
+ del2.swap (del); // del2 now carries the previous buffer, if any
141
+ cur = xt;
142
+ }
143
+ del.release (); // the final buffer is handed to the caller
144
+ return cur;
145
+ }
146
+
147
+ void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
148
+ {
149
+ ScopeDeleter<float> del; // carries the latest temporary into the next step
150
+ const float* src = xt;
151
+ // walk the chain backwards; only the final step (stage 0) writes into x
152
+ for (int i = chain.size() - 1; i >= 0; i--) {
153
+ float* dst = (i != 0) ? new float [n * chain[i]->d_in] : x;
154
+ ScopeDeleter<float> del2 ((dst != x) ? dst : nullptr);
155
+ chain [i]->reverse_transform (n, src, dst);
156
+ del2.swap (del);
157
+ src = dst;
158
+ }
159
+ }
160
+
161
+ void IndexPreTransform::add (idx_t n, const float *x)
162
+ {
163
+ FAISS_THROW_IF_NOT (is_trained);
164
+ const float *transformed = apply_chain (n, x);
165
+ ScopeDeleter<float> guard (transformed != x ? transformed : nullptr); // nullptr when the result is x itself
166
+ index->add (n, transformed);
167
+ ntotal = index->ntotal; // stay in sync with the sub-index
168
+ }
169
+
170
+ void IndexPreTransform::add_with_ids (idx_t n, const float * x,
171
+ const idx_t *xids)
172
+ {
173
+ FAISS_THROW_IF_NOT (is_trained);
174
+ const float *transformed = apply_chain (n, x);
175
+ ScopeDeleter<float> guard (transformed != x ? transformed : nullptr); // nullptr when the result is x itself
176
+ index->add_with_ids (n, transformed, xids);
177
+ ntotal = index->ntotal; // stay in sync with the sub-index
178
+ }
179
+
180
+
181
+
182
+
183
+ void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
184
+ float *distances, idx_t *labels) const
185
+ {
186
+ FAISS_THROW_IF_NOT (is_trained);
187
+ const float *queries = apply_chain (n, x); // queries pass through the chain first
188
+ ScopeDeleter<float> guard (queries != x ? queries : nullptr);
189
+ index->search (n, queries, k, distances, labels);
190
+ }
191
+
192
+ void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
193
+ RangeSearchResult* result) const
194
+ {
195
+ FAISS_THROW_IF_NOT (is_trained);
196
+ const float *queries = apply_chain (n, x); // queries pass through the chain first
197
+ ScopeDeleter<float> guard (queries != x ? queries : nullptr);
198
+ index->range_search (n, queries, radius, result);
199
+ }
200
+
201
+
202
+
203
+ void IndexPreTransform::reset ()
204
+ { index->reset(); // forward the reset to the sub-index
205
+ ntotal = 0; // the wrapper's own count goes back to zero
206
+ }
207
+
208
+ size_t IndexPreTransform::remove_ids (const IDSelector & sel) {
209
+ const size_t n_removed = index->remove_ids (sel); // delegate to the sub-index
210
+ ntotal = index->ntotal; // mirror the sub-index's new size
211
+ return n_removed;
212
+ }
213
+
214
+
215
+ void IndexPreTransform::reconstruct (idx_t key, float * recons) const
216
+ {
217
+ // with no transforms the sub-index writes straight into recons
218
+ float *tmp = !chain.empty() ? new float [index->d] : recons;
219
+ ScopeDeleter<float> guard (tmp != recons ? tmp : nullptr);
220
+ // Vector as reconstructed by the sub-index (index->d floats)
221
+ index->reconstruct (key, tmp);
222
+ // Undo the transforms, last to first, ending in recons
223
+ reverse_chain (1, tmp, recons);
224
+ }
225
+
226
+
227
+ void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
228
+ {
229
+ // with no transforms the sub-index writes straight into recons
230
+ float *tmp = !chain.empty() ? new float [ni * index->d] : recons;
231
+ ScopeDeleter<float> guard (tmp != recons ? tmp : nullptr);
232
+ // ni vectors as reconstructed by the sub-index (index->d floats each)
233
+ index->reconstruct_n (i0, ni, tmp);
234
+ // Undo the transforms, last to first, ending in recons
235
+ reverse_chain (ni, tmp, recons);
236
+ }
237
+
238
+
239
+ void IndexPreTransform::search_and_reconstruct (
240
+ idx_t n, const float *x, idx_t k,
241
+ float *distances, idx_t *labels, float* recons) const
242
+ {
243
+ FAISS_THROW_IF_NOT (is_trained);
244
+ // queries go forward through the chain before reaching the sub-index
245
+ const float* queries = apply_chain (n, x);
246
+ ScopeDeleter<float> queries_guard ((queries != x) ? queries : nullptr);
247
+ // room for n * k vectors of index->d floats; recons is used directly without transforms
248
+ float* raw = !chain.empty() ? new float [n * k * index->d] : recons;
249
+ ScopeDeleter<float> raw_guard ((raw != recons) ? raw : nullptr);
250
+ index->search_and_reconstruct (n, queries, k, distances, labels, raw);
251
+
252
+ // Undo the transforms, last to first, for all n * k vectors
253
+ reverse_chain (n * k, raw, recons);
254
+ }
255
+
256
+ /// Standalone-codec code size, exactly as reported by the sub-index.
257
+ size_t IndexPreTransform::sa_code_size () const {
258
+ return index->sa_code_size ();
259
+ }
260
+
261
+ void IndexPreTransform::sa_encode (idx_t n, const float *x,
262
+ uint8_t *bytes) const
263
+ {
264
+ if (chain.empty()) { // nothing to apply: encode the input as is
265
+ index->sa_encode (n, x, bytes);
266
+ return;
267
+ }
268
+ const float *transformed = apply_chain (n, x);
269
+ ScopeDeleter<float> guard (transformed != x ? transformed : nullptr);
270
+ index->sa_encode (n, transformed, bytes);
271
+ }
272
+
273
+ void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
274
+ float *x) const
275
+ {
276
+ if (chain.empty()) { // no transform to undo: decode straight into x
277
+ index->sa_decode (n, bytes, x);
278
+ return;
279
+ }
280
+ std::unique_ptr<float []> decoded (new float [index->d * n]);
281
+ index->sa_decode (n, bytes, decoded.get());
282
+ // Undo the transforms, last to first, writing the result into x
283
+ reverse_chain (n, decoded.get(), x);
284
+ }
285
+
286
+
287
+
288
+ } // namespace faiss