faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,199 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+ #ifndef FAISS_LATTICE_ZN_H
10
+ #define FAISS_LATTICE_ZN_H
11
+
12
+ #include <vector>
13
+ #include <stddef.h>
14
+ #include <stdint.h>
15
+
16
+ namespace faiss {
17
+
18
/** Finds the vertex of the Zn sphere that is nearest to a query vector.
 *
 * Only the coordinates of that vertex are produced, never an id.
 *
 * How it works: every point of the sphere can be obtained from one
 * "atom" vector by permuting its coordinates and changing signs. A
 * search therefore picks the best-matching atom together with the
 * permutation / sign pattern that brings it closest to the query.
 */
struct ZnSphereSearch {
    int dimS;    // dimension (presumably the `dim` given to the ctor)
    int r2;      // sphere parameter passed to the ctor
    int natom;   // number of atoms

    /// atom vocabulary, size dim * natom
    std::vector<float> voc;

    ZnSphereSearch(int dim, int r2);

    /// Nearest centroid of x, written to c. x need not be normalized.
    float search(const float *x, float *c) const;

    /// Same search, with caller-provided scratch space:
    /// tmp must hold 2 * dim floats, tmp_int must hold dim ints.
    /// ibest_out is optional (NOTE(review): presumably receives the
    /// index of the selected atom — confirm in lattice_Zn.cpp).
    float search(
        const float *x,
        float *c,
        float *tmp,
        int *tmp_int,
        int *ibest_out = nullptr) const;

    /// multi-threaded batch version over n query vectors
    void search_multi(int n, const float *x, float *c_out, float *dp_out);
};
50
+
51
+
52
+ /***************************************************************************
53
+ * Support ids as well.
54
+ *
55
+ * Limitations: ids are limited to 64 bit
56
+ ***************************************************************************/
57
+
58
/** Abstract, enumerable collection of vectors: every member of the
 * collection is identified by a 64-bit code. Subclasses supply the
 * encode / decode pair; the batch helpers are built on top of them.
 */
struct EnumeratedVectors {
    uint64_t nv;   ///< size of the collection (starts at 0)
    int dim;       ///< dimension of the vectors

    explicit EnumeratedVectors(int dim) : nv(0), dim(dim) {}

    /// map a vector of the collection to its code
    virtual uint64_t encode(const float *x) const = 0;

    /// reconstruct into c the vector identified by code
    virtual void decode(uint64_t code, float *c) const = 0;

    virtual ~EnumeratedVectors() {}

    /// encode nc vectors stored contiguously in c
    void encode_multi(size_t nc, const float *c, uint64_t *codes) const;

    /// decode nc codes into the contiguous buffer c
    void decode_multi(size_t nc, const uint64_t *codes, float *c) const;

    /// For each of the nq queries xq, find the nearest neighbor among the
    /// n encoded vectors (decodes the codes and computes distances).
    void find_nn(size_t n, const uint64_t *codes,
                 size_t nq, const float *xq,
                 long *idx, float *dis);
};
88
+
89
/// One run of a Repeats pattern: the value `val` occurring `n` times.
struct Repeat {
    float val;   // the repeated coordinate value
    int n;       // how many coordinates hold it
};
93
+
94
+ /** Repeats: used to encode a vector that has n occurrences of
95
+ * val. Encodes the signs and permutation of the vector. Useful for
96
+ * atoms.
97
+ */
98
+ struct Repeats {
99
+ int dim;
100
+ std::vector<Repeat> repeats;
101
+
102
+ // initialize from a template of the atom.
103
+ Repeats(int dim = 0, const float *c = nullptr);
104
+
105
+ // count number of possible codes for this atom
106
+ long count() const;
107
+
108
+ long encode(const float *c) const;
109
+
110
+ void decode(uint64_t code, float *c) const;
111
+ };
112
+
113
+
114
+ /** codec that can return ids for the encoded vectors
115
+ *
116
+ * uses the ZnSphereSearch to encode the vector by encoding the
117
+ * permutation and signs. Depends on ZnSphereSearch because it uses
118
+ * the atom numbers */
119
+ struct ZnSphereCodec: ZnSphereSearch, EnumeratedVectors {
120
+
121
+ struct CodeSegment:Repeats {
122
+ explicit CodeSegment(const Repeats & r): Repeats(r) {}
123
+ uint64_t c0; // first code assigned to segment
124
+ int signbits;
125
+ };
126
+
127
+ std::vector<CodeSegment> code_segments;
128
+ uint64_t nv;
129
+ size_t code_size;
130
+
131
+ ZnSphereCodec(int dim, int r2);
132
+
133
+ uint64_t search_and_encode(const float *x) const;
134
+
135
+ void decode(uint64_t code, float *c) const override;
136
+
137
+ /// takes vectors that do not need to be centroids
138
+ uint64_t encode(const float *x) const override;
139
+
140
+ };
141
+
142
+ /** recursive sphere codec
143
+ *
144
+ * Uses a recursive decomposition on the dimensions to encode
145
+ * centroids found by the ZnSphereSearch. The codes are *not*
146
+ * compatible with the ones of ZnSpehreCodec
147
+ */
148
+ struct ZnSphereCodecRec: EnumeratedVectors {
149
+
150
+ int r2;
151
+
152
+ int log2_dim;
153
+ int code_size;
154
+
155
+ ZnSphereCodecRec(int dim, int r2);
156
+
157
+ uint64_t encode_centroid(const float *c) const;
158
+
159
+ void decode(uint64_t code, float *c) const override;
160
+
161
+ /// vectors need to be centroids (does not work on arbitrary
162
+ /// vectors)
163
+ uint64_t encode(const float *x) const override;
164
+
165
+ std::vector<uint64_t> all_nv;
166
+ std::vector<uint64_t> all_nv_cum;
167
+
168
+ int decode_cache_ld;
169
+ std::vector<std::vector<float> > decode_cache;
170
+
171
+ // nb of vectors in the sphere in dim 2^ld with r2 radius
172
+ uint64_t get_nv(int ld, int r2a) const;
173
+
174
+ // cumulative version
175
+ uint64_t get_nv_cum(int ld, int r2t, int r2a) const;
176
+ void set_nv_cum(int ld, int r2t, int r2a, uint64_t v);
177
+
178
+ };
179
+
180
+
181
+ /** Codec that uses the recursive codec if dim is a power of 2 and
182
+ * the regular one otherwise */
183
+ struct ZnSphereCodecAlt: ZnSphereCodec {
184
+ bool use_rec;
185
+ ZnSphereCodecRec znc_rec;
186
+
187
+ ZnSphereCodecAlt (int dim, int r2);
188
+
189
+ uint64_t encode(const float *x) const override;
190
+
191
+ void decode(uint64_t code, float *c) const override;
192
+
193
+ };
194
+
195
+
196
+ };
197
+
198
+
199
+ #endif
@@ -0,0 +1,392 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ /*
11
+ * implementation of Hyper-parameter auto-tuning
12
+ */
13
+
14
+ #include <faiss/AutoTune.h>
15
+
16
+ #include <cmath>
17
+ #include <stdarg.h> /* va_list, va_start, va_arg, va_end */
18
+
19
+
20
+ #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/utils/utils.h>
22
+ #include <faiss/utils/random.h>
23
+
24
+ #include <faiss/IndexFlat.h>
25
+ #include <faiss/VectorTransform.h>
26
+ #include <faiss/IndexPreTransform.h>
27
+ #include <faiss/IndexLSH.h>
28
+ #include <faiss/IndexPQ.h>
29
+ #include <faiss/IndexIVF.h>
30
+ #include <faiss/IndexIVFPQ.h>
31
+ #include <faiss/IndexIVFPQR.h>
32
+ #include <faiss/Index2Layer.h>
33
+ #include <faiss/IndexIVFFlat.h>
34
+ #include <faiss/MetaIndexes.h>
35
+ #include <faiss/IndexScalarQuantizer.h>
36
+ #include <faiss/IndexHNSW.h>
37
+ #include <faiss/IndexLattice.h>
38
+
39
+ #include <faiss/IndexBinaryFlat.h>
40
+ #include <faiss/IndexBinaryHNSW.h>
41
+ #include <faiss/IndexBinaryIVF.h>
42
+
43
+ namespace faiss {
44
+
45
+
46
+ /***************************************************************
47
+ * index_factory
48
+ ***************************************************************/
49
+
50
+ namespace {
51
+
52
+ struct VTChain {
53
+ std::vector<VectorTransform *> chain;
54
+ ~VTChain () {
55
+ for (int i = 0; i < chain.size(); i++) {
56
+ delete chain[i];
57
+ }
58
+ }
59
+ };
60
+
61
+
62
+ /// what kind of training does this coarse quantizer require?
63
+ char get_trains_alone(const Index *coarse_quantizer) {
64
+ return
65
+ dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ? 1 :
66
+ dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer) ? 2 :
67
+ 0;
68
+ }
69
+
70
+
71
+ }
72
+
73
+ Index *index_factory (int d, const char *description_in, MetricType metric)
74
+ {
75
+ FAISS_THROW_IF_NOT(metric == METRIC_L2 ||
76
+ metric == METRIC_INNER_PRODUCT);
77
+ VTChain vts;
78
+ Index *coarse_quantizer = nullptr;
79
+ Index *index = nullptr;
80
+ bool add_idmap = false;
81
+ bool make_IndexRefineFlat = false;
82
+
83
+ ScopeDeleter1<Index> del_coarse_quantizer, del_index;
84
+
85
+ char description[strlen(description_in) + 1];
86
+ char *ptr;
87
+ memcpy (description, description_in, strlen(description_in) + 1);
88
+
89
+ int64_t ncentroids = -1;
90
+ bool use_2layer = false;
91
+
92
+ for (char *tok = strtok_r (description, " ,", &ptr);
93
+ tok;
94
+ tok = strtok_r (nullptr, " ,", &ptr)) {
95
+ int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2;
96
+ std::string stok(tok);
97
+ nbit = 8;
98
+
99
+ // to avoid mem leaks with exceptions:
100
+ // do all tests before any instanciation
101
+
102
+ VectorTransform *vt_1 = nullptr;
103
+ Index *coarse_quantizer_1 = nullptr;
104
+ Index *index_1 = nullptr;
105
+
106
+ // VectorTransforms
107
+ if (sscanf (tok, "PCA%d", &d_out) == 1) {
108
+ vt_1 = new PCAMatrix (d, d_out);
109
+ d = d_out;
110
+ } else if (sscanf (tok, "PCAR%d", &d_out) == 1) {
111
+ vt_1 = new PCAMatrix (d, d_out, 0, true);
112
+ d = d_out;
113
+ } else if (sscanf (tok, "RR%d", &d_out) == 1) {
114
+ vt_1 = new RandomRotationMatrix (d, d_out);
115
+ d = d_out;
116
+ } else if (sscanf (tok, "PCAW%d", &d_out) == 1) {
117
+ vt_1 = new PCAMatrix (d, d_out, -0.5, false);
118
+ d = d_out;
119
+ } else if (sscanf (tok, "PCAWR%d", &d_out) == 1) {
120
+ vt_1 = new PCAMatrix (d, d_out, -0.5, true);
121
+ d = d_out;
122
+ } else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
123
+ vt_1 = new OPQMatrix (d, opq_M, d_out);
124
+ d = d_out;
125
+ } else if (sscanf (tok, "OPQ%d", &opq_M) == 1) {
126
+ vt_1 = new OPQMatrix (d, opq_M);
127
+ } else if (sscanf (tok, "ITQ%d", &d_out) == 1) {
128
+ vt_1 = new ITQTransform (d, d_out, true);
129
+ d = d_out;
130
+ } else if (stok == "ITQ") {
131
+ vt_1 = new ITQTransform (d, d, false);
132
+ } else if (sscanf (tok, "Pad%d", &d_out) == 1) {
133
+ if (d_out > d) {
134
+ vt_1 = new RemapDimensionsTransform (d, d_out, false);
135
+ d = d_out;
136
+ }
137
+ } else if (stok == "L2norm") {
138
+ vt_1 = new NormalizationTransform (d, 2.0);
139
+
140
+ // coarse quantizers
141
+ } else if (!coarse_quantizer &&
142
+ sscanf (tok, "IVF%ld_HNSW%d", &ncentroids, &M) == 2) {
143
+ FAISS_THROW_IF_NOT (metric == METRIC_L2);
144
+ coarse_quantizer_1 = new IndexHNSWFlat (d, M);
145
+
146
+ } else if (!coarse_quantizer &&
147
+ sscanf (tok, "IVF%ld", &ncentroids) == 1) {
148
+ if (metric == METRIC_L2) {
149
+ coarse_quantizer_1 = new IndexFlatL2 (d);
150
+ } else {
151
+ coarse_quantizer_1 = new IndexFlatIP (d);
152
+ }
153
+ } else if (!coarse_quantizer && sscanf (tok, "IMI2x%d", &nbit) == 1) {
154
+ FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2,
155
+ "MultiIndex not implemented for inner prod search");
156
+ coarse_quantizer_1 = new MultiIndexQuantizer (d, 2, nbit);
157
+ ncentroids = 1 << (2 * nbit);
158
+
159
+ } else if (!coarse_quantizer &&
160
+ sscanf (tok, "Residual%dx%d", &M, &nbit) == 2) {
161
+ FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2,
162
+ "MultiIndex not implemented for inner prod search");
163
+ coarse_quantizer_1 = new MultiIndexQuantizer (d, M, nbit);
164
+ ncentroids = int64_t(1) << (M * nbit);
165
+ use_2layer = true;
166
+
167
+ } else if (!coarse_quantizer &&
168
+ sscanf (tok, "Residual%ld", &ncentroids) == 1) {
169
+ coarse_quantizer_1 = new IndexFlatL2 (d);
170
+ use_2layer = true;
171
+
172
+ } else if (stok == "IDMap") {
173
+ add_idmap = true;
174
+
175
+ // IVFs
176
+ } else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
177
+ if (coarse_quantizer) {
178
+ // if there was an IVF in front, then it is an IVFFlat
179
+ IndexIVF *index_ivf = stok == "Flat" ?
180
+ new IndexIVFFlat (
181
+ coarse_quantizer, d, ncentroids, metric) :
182
+ new IndexIVFFlatDedup (
183
+ coarse_quantizer, d, ncentroids, metric);
184
+ index_ivf->quantizer_trains_alone =
185
+ get_trains_alone (coarse_quantizer);
186
+ index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
187
+ del_coarse_quantizer.release ();
188
+ index_ivf->own_fields = true;
189
+ index_1 = index_ivf;
190
+ } else {
191
+ FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup",
192
+ "dedup supported only for IVFFlat");
193
+ index_1 = new IndexFlat (d, metric);
194
+ }
195
+ } else if (!index && (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
196
+ stok == "SQfp16")) {
197
+ ScalarQuantizer::QuantizerType qt =
198
+ stok == "SQ8" ? ScalarQuantizer::QT_8bit :
199
+ stok == "SQ6" ? ScalarQuantizer::QT_6bit :
200
+ stok == "SQ4" ? ScalarQuantizer::QT_4bit :
201
+ stok == "SQfp16" ? ScalarQuantizer::QT_fp16 :
202
+ ScalarQuantizer::QT_4bit;
203
+ if (coarse_quantizer) {
204
+ FAISS_THROW_IF_NOT (!use_2layer);
205
+ IndexIVFScalarQuantizer *index_ivf =
206
+ new IndexIVFScalarQuantizer (
207
+ coarse_quantizer, d, ncentroids, qt, metric);
208
+ index_ivf->quantizer_trains_alone =
209
+ get_trains_alone (coarse_quantizer);
210
+ del_coarse_quantizer.release ();
211
+ index_ivf->own_fields = true;
212
+ index_1 = index_ivf;
213
+ } else {
214
+ index_1 = new IndexScalarQuantizer (d, qt, metric);
215
+ }
216
+ } else if (!index && sscanf (tok, "PQ%d+%d", &M, &M2) == 2) {
217
+ FAISS_THROW_IF_NOT_MSG(coarse_quantizer,
218
+ "PQ with + works only with an IVF");
219
+ FAISS_THROW_IF_NOT_MSG(metric == METRIC_L2,
220
+ "IVFPQR not implemented for inner product search");
221
+ IndexIVFPQR *index_ivf = new IndexIVFPQR (
222
+ coarse_quantizer, d, ncentroids, M, 8, M2, 8);
223
+ index_ivf->quantizer_trains_alone =
224
+ get_trains_alone (coarse_quantizer);
225
+ del_coarse_quantizer.release ();
226
+ index_ivf->own_fields = true;
227
+ index_1 = index_ivf;
228
+ } else if (!index && (sscanf (tok, "PQ%dx%d", &M, &nbit) == 2 ||
229
+ sscanf (tok, "PQ%d", &M) == 1 ||
230
+ sscanf (tok, "PQ%dnp", &M) == 1)) {
231
+ bool do_polysemous_training = stok.find("np") == std::string::npos;
232
+ if (coarse_quantizer) {
233
+ if (!use_2layer) {
234
+ IndexIVFPQ *index_ivf = new IndexIVFPQ (
235
+ coarse_quantizer, d, ncentroids, M, nbit);
236
+ index_ivf->quantizer_trains_alone =
237
+ get_trains_alone (coarse_quantizer);
238
+ index_ivf->metric_type = metric;
239
+ index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
240
+ del_coarse_quantizer.release ();
241
+ index_ivf->own_fields = true;
242
+ index_ivf->do_polysemous_training = do_polysemous_training;
243
+ index_1 = index_ivf;
244
+ } else {
245
+ Index2Layer *index_2l = new Index2Layer
246
+ (coarse_quantizer, ncentroids, M, nbit);
247
+ index_2l->q1.quantizer_trains_alone =
248
+ get_trains_alone (coarse_quantizer);
249
+ index_2l->q1.own_fields = true;
250
+ index_1 = index_2l;
251
+ }
252
+ } else {
253
+ IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric);
254
+ index_pq->do_polysemous_training = do_polysemous_training;
255
+ index_1 = index_pq;
256
+ }
257
+ } else if (!index &&
258
+ sscanf (tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
259
+ Index * quant = new IndexFlatL2 (d);
260
+ IndexHNSW2Level * hidx2l = new IndexHNSW2Level (quant, ncent, pq_m, M);
261
+ Index2Layer * idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
262
+ idx2l->q1.own_fields = true;
263
+ index_1 = hidx2l;
264
+ } else if (!index &&
265
+ sscanf (tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
266
+ Index * quant = new MultiIndexQuantizer (d, 2, nbit);
267
+ IndexHNSW2Level * hidx2l =
268
+ new IndexHNSW2Level (quant, 1 << (2 * nbit), pq_m, M);
269
+ Index2Layer * idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
270
+ idx2l->q1.own_fields = true;
271
+ idx2l->q1.quantizer_trains_alone = 1;
272
+ index_1 = hidx2l;
273
+ } else if (!index &&
274
+ sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
275
+ index_1 = new IndexHNSWPQ (d, pq_m, M);
276
+ } else if (!index &&
277
+ sscanf (tok, "HNSW%d", &M) == 1) {
278
+ index_1 = new IndexHNSWFlat (d, M);
279
+ } else if (!index &&
280
+ sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
281
+ pq_m == 8) {
282
+ index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M);
283
+ } else if (!index && (stok == "LSH" || stok == "LSHr" ||
284
+ stok == "LSHrt" || stok == "LSHt")) {
285
+ bool rotate_data = strstr(tok, "r") != nullptr;
286
+ bool train_thresholds = strstr(tok, "t") != nullptr;
287
+ index_1 = new IndexLSH (d, d, rotate_data, train_thresholds);
288
+ } else if (!index &&
289
+ sscanf (tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
290
+ FAISS_THROW_IF_NOT(!coarse_quantizer);
291
+ index_1 = new IndexLattice(d, M, nbit, r2);
292
+ } else if (stok == "RFlat") {
293
+ make_IndexRefineFlat = true;
294
+ } else {
295
+ FAISS_THROW_FMT( "could not parse token \"%s\" in %s\n",
296
+ tok, description_in);
297
+ }
298
+
299
+ if (index_1 && add_idmap) {
300
+ IndexIDMap *idmap = new IndexIDMap(index_1);
301
+ del_index.set (idmap);
302
+ idmap->own_fields = true;
303
+ index_1 = idmap;
304
+ add_idmap = false;
305
+ }
306
+
307
+ if (vt_1) {
308
+ vts.chain.push_back (vt_1);
309
+ }
310
+
311
+ if (coarse_quantizer_1) {
312
+ coarse_quantizer = coarse_quantizer_1;
313
+ del_coarse_quantizer.set (coarse_quantizer);
314
+ }
315
+
316
+ if (index_1) {
317
+ index = index_1;
318
+ del_index.set (index);
319
+ }
320
+ }
321
+
322
+ FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
323
+ description_in);
324
+
325
+ // nothing can go wrong now
326
+ del_index.release ();
327
+ del_coarse_quantizer.release ();
328
+
329
+ if (add_idmap) {
330
+ fprintf(stderr, "index_factory: WARNING: "
331
+ "IDMap option not used\n");
332
+ }
333
+
334
+ if (vts.chain.size() > 0) {
335
+ IndexPreTransform *index_pt = new IndexPreTransform (index);
336
+ index_pt->own_fields = true;
337
+ // add from back
338
+ while (vts.chain.size() > 0) {
339
+ index_pt->prepend_transform (vts.chain.back ());
340
+ vts.chain.pop_back ();
341
+ }
342
+ index = index_pt;
343
+ }
344
+
345
+ if (make_IndexRefineFlat) {
346
+ IndexRefineFlat *index_rf = new IndexRefineFlat (index);
347
+ index_rf->own_fields = true;
348
+ index = index_rf;
349
+ }
350
+
351
+ return index;
352
+ }
353
+
354
+ IndexBinary *index_binary_factory(int d, const char *description)
355
+ {
356
+ IndexBinary *index = nullptr;
357
+
358
+ int ncentroids = -1;
359
+ int M;
360
+
361
+ if (sscanf(description, "BIVF%d_HNSW%d", &ncentroids, &M) == 2) {
362
+ IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
363
+ new IndexBinaryHNSW(d, M), d, ncentroids
364
+ );
365
+ index_ivf->own_fields = true;
366
+ index = index_ivf;
367
+
368
+ } else if (sscanf(description, "BIVF%d", &ncentroids) == 1) {
369
+ IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
370
+ new IndexBinaryFlat(d), d, ncentroids
371
+ );
372
+ index_ivf->own_fields = true;
373
+ index = index_ivf;
374
+
375
+ } else if (sscanf(description, "BHNSW%d", &M) == 1) {
376
+ IndexBinaryHNSW *index_hnsw = new IndexBinaryHNSW(d, M);
377
+ index = index_hnsw;
378
+
379
+ } else if (std::string(description) == "BFlat") {
380
+ index = new IndexBinaryFlat(d);
381
+
382
+ } else {
383
+ FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
384
+ description);
385
+ }
386
+
387
+ return index;
388
+ }
389
+
390
+
391
+
392
+ } // namespace faiss