faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -5,92 +5,80 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  /*
11
- * implementation of Hyper-parameter auto-tuning
9
+ * implementation of the index_factory function. Lots of regex parsing code.
12
10
  */
13
11
 
14
- #include <faiss/AutoTune.h>
12
+ #include <faiss/index_factory.h>
13
+ #include "faiss/MetricType.h"
14
+ #include "faiss/impl/FaissAssert.h"
15
15
 
16
16
  #include <cinttypes>
17
17
  #include <cmath>
18
18
 
19
+ #include <map>
20
+
21
+ #include <regex>
22
+
19
23
  #include <faiss/impl/FaissAssert.h>
20
- #include <faiss/utils/utils.h>
21
24
  #include <faiss/utils/random.h>
25
+ #include <faiss/utils/utils.h>
22
26
 
27
+ #include <faiss/Index2Layer.h>
28
+ #include <faiss/IndexAdditiveQuantizer.h>
23
29
  #include <faiss/IndexFlat.h>
24
- #include <faiss/VectorTransform.h>
25
- #include <faiss/IndexPreTransform.h>
26
- #include <faiss/IndexLSH.h>
27
- #include <faiss/IndexPQ.h>
30
+ #include <faiss/IndexHNSW.h>
28
31
  #include <faiss/IndexIVF.h>
32
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
33
+ #include <faiss/IndexIVFFlat.h>
29
34
  #include <faiss/IndexIVFPQ.h>
35
+ #include <faiss/IndexIVFPQFastScan.h>
30
36
  #include <faiss/IndexIVFPQR.h>
31
- #include <faiss/Index2Layer.h>
32
- #include <faiss/IndexIVFFlat.h>
33
- #include <faiss/MetaIndexes.h>
34
- #include <faiss/IndexScalarQuantizer.h>
35
- #include <faiss/IndexHNSW.h>
37
+ #include <faiss/IndexIVFSpectralHash.h>
38
+ #include <faiss/IndexLSH.h>
36
39
  #include <faiss/IndexLattice.h>
40
+ #include <faiss/IndexNSG.h>
41
+ #include <faiss/IndexPQ.h>
37
42
  #include <faiss/IndexPQFastScan.h>
38
- #include <faiss/IndexIVFPQFastScan.h>
43
+ #include <faiss/IndexPreTransform.h>
39
44
  #include <faiss/IndexRefine.h>
40
-
45
+ #include <faiss/IndexScalarQuantizer.h>
46
+ #include <faiss/MetaIndexes.h>
47
+ #include <faiss/VectorTransform.h>
41
48
 
42
49
  #include <faiss/IndexBinaryFlat.h>
43
50
  #include <faiss/IndexBinaryHNSW.h>
44
- #include <faiss/IndexBinaryIVF.h>
45
51
  #include <faiss/IndexBinaryHash.h>
52
+ #include <faiss/IndexBinaryIVF.h>
53
+ #include <string>
46
54
 
47
55
  namespace faiss {
48
56
 
49
-
50
57
  /***************************************************************
51
58
  * index_factory
52
59
  ***************************************************************/
53
60
 
54
- namespace {
61
+ int index_factory_verbose = 0;
55
62
 
56
- struct VTChain {
57
- std::vector<VectorTransform *> chain;
58
- ~VTChain () {
59
- for (int i = 0; i < chain.size(); i++) {
60
- delete chain[i];
61
- }
62
- }
63
- };
64
-
65
-
66
- /// what kind of training does this coarse quantizer require?
67
- char get_trains_alone(const Index *coarse_quantizer) {
68
- return
69
- dynamic_cast<const IndexFlat*>(coarse_quantizer) ? 0 :
70
- // multi index just needs to be quantized
71
- dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ? 1 :
72
- dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer) ? 2 :
73
- 2; // for complicated indexes, we assume they can't be used as a kmeans index
74
- }
63
+ namespace {
75
64
 
76
- bool str_ends_with(const std::string& s, const std::string& suffix)
77
- {
78
- return s.rfind(suffix) == std::abs(int(s.size()-suffix.size()));
79
- }
65
+ /***************************************************************
66
+ * Small functions
67
+ */
80
68
 
81
- // check if ends with suffix followed by digits
82
- bool str_ends_with_digits(const std::string& s, const std::string& suffix)
83
- {
84
- int i;
85
- for(i = s.length() - 1; i >= 0; i--) {
86
- if (!isdigit(s[i])) break;
87
- }
88
- return str_ends_with(s.substr(0, i + 1), suffix);
69
+ bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
70
+ return std::regex_match(s, sm, std::regex(pat));
89
71
  }
90
72
 
91
- void find_matching_parentheses(const std::string &s, int & i0, int & i1) {
73
+ // find first pair of matching parentheses
74
+ void find_matching_parentheses(
75
+ const std::string& s,
76
+ int& i0,
77
+ int& i1,
78
+ int begin = 0) {
92
79
  int st = 0;
93
- for (int i = 0; i < s.length(); i++) {
80
+ i0 = i1 = 0;
81
+ for (int i = begin; i < s.length(); i++) {
94
82
  if (s[i] == '(') {
95
83
  if (st == 0) {
96
84
  i0 = i;
@@ -105,413 +93,647 @@ void find_matching_parentheses(const std::string &s, int & i0, int & i1) {
105
93
  return;
106
94
  }
107
95
  if (st < 0) {
108
- FAISS_THROW_FMT("factory string %s: unbalanced parentheses", s.c_str());
96
+ FAISS_THROW_FMT(
97
+ "factory string %s: unbalanced parentheses", s.c_str());
109
98
  }
110
99
  }
100
+ }
101
+ FAISS_THROW_FMT(
102
+ "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
103
+ }
111
104
 
105
+ /// what kind of training does this coarse quantizer require?
106
+ char get_trains_alone(const Index* coarse_quantizer) {
107
+ if (dynamic_cast<const IndexFlat*>(coarse_quantizer)) {
108
+ return 0;
109
+ }
110
+ // multi index just needs to be quantized
111
+ if (dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ||
112
+ dynamic_cast<const ResidualCoarseQuantizer*>(coarse_quantizer)) {
113
+ return 1;
112
114
  }
113
- FAISS_THROW_FMT("factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
115
+ if (dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer)) {
116
+ return 2;
117
+ }
118
+ return 2; // for complicated indexes, we assume they can't be used as a
119
+ // kmeans index
120
+ }
114
121
 
122
+ // set the fields for factory-constructed IVF structures
123
+ IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
124
+ index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
125
+ index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
126
+ index_ivf->own_fields = true;
127
+ return index_ivf;
115
128
  }
116
129
 
117
- } // anonymous namespace
130
+ int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
131
+ if (mr.length() == 0) {
132
+ return deflt;
133
+ }
134
+ return std::stoi(mr.str().substr(begin));
135
+ }
118
136
 
119
- Index *index_factory (int d, const char *description_in, MetricType metric)
120
- {
121
- FAISS_THROW_IF_NOT(metric == METRIC_L2 ||
122
- metric == METRIC_INNER_PRODUCT);
123
- VTChain vts;
124
- Index *coarse_quantizer = nullptr;
125
- std::string parenthesis_ivf, parenthesis_refine;
126
- Index *index = nullptr;
127
- bool add_idmap = false;
128
- int d_in = d;
129
-
130
- ScopeDeleter1<Index> del_coarse_quantizer, del_index;
131
-
132
- std::string description(description_in);
133
- char *ptr;
134
-
135
- // handle indexes in parentheses
136
- while (description.find('(') != std::string::npos) {
137
- // then we make a sub-index and remove the () from the description
138
- int i0, i1;
139
- find_matching_parentheses(description, i0, i1);
137
+ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
138
+ {"SQ8", ScalarQuantizer::QT_8bit},
139
+ {"SQ4", ScalarQuantizer::QT_4bit},
140
+ {"SQ6", ScalarQuantizer::QT_6bit},
141
+ {"SQfp16", ScalarQuantizer::QT_fp16},
142
+ };
143
+ const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
144
+
145
+ std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
146
+ {"_Nfloat", AdditiveQuantizer::ST_norm_float},
147
+ {"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
148
+ {"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
149
+ {"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
150
+ {"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
151
+ {"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
152
+ };
140
153
 
141
- std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
154
+ const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
155
+ const std::string aq_norm_pattern =
156
+ "(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4)";
142
157
 
143
- if (str_ends_with_digits(description.substr(0, i0), "IVF")) {
144
- parenthesis_ivf = sub_description;
145
- } else if (str_ends_with(description.substr(0, i0), "Refine")) {
146
- parenthesis_refine = sub_description;
147
- } else {
148
- FAISS_THROW_MSG("don't know what to do with parenthesis index");
149
- }
150
- description = description.erase(i0, i1 - i0 + 1);
151
- }
152
-
153
- int64_t ncentroids = -1;
154
- bool use_2layer = false;
155
- int hnsw_M = -1;
156
-
157
- for (char *tok = strtok_r (&description[0], " ,", &ptr);
158
- tok;
159
- tok = strtok_r (nullptr, " ,", &ptr)) {
160
- int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2;
161
- std::string stok(tok);
162
- nbit = 8;
163
- int bbs = -1;
164
- char c;
165
-
166
- // to avoid mem leaks with exceptions:
167
- // do all tests before any instanciation
168
-
169
- VectorTransform *vt_1 = nullptr;
170
- Index *coarse_quantizer_1 = nullptr;
171
- Index *index_1 = nullptr;
172
-
173
- // VectorTransforms
174
- if (sscanf (tok, "PCA%d", &d_out) == 1) {
175
- vt_1 = new PCAMatrix (d, d_out);
176
- d = d_out;
177
- } else if (sscanf (tok, "PCAR%d", &d_out) == 1) {
178
- vt_1 = new PCAMatrix (d, d_out, 0, true);
179
- d = d_out;
180
- } else if (sscanf (tok, "RR%d", &d_out) == 1) {
181
- vt_1 = new RandomRotationMatrix (d, d_out);
182
- d = d_out;
183
- } else if (sscanf (tok, "PCAW%d", &d_out) == 1) {
184
- vt_1 = new PCAMatrix (d, d_out, -0.5, false);
185
- d = d_out;
186
- } else if (sscanf (tok, "PCAWR%d", &d_out) == 1) {
187
- vt_1 = new PCAMatrix (d, d_out, -0.5, true);
188
- d = d_out;
189
- } else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
190
- vt_1 = new OPQMatrix (d, opq_M, d_out);
191
- d = d_out;
192
- } else if (sscanf (tok, "OPQ%d", &opq_M) == 1) {
193
- vt_1 = new OPQMatrix (d, opq_M);
194
- } else if (sscanf (tok, "ITQ%d", &d_out) == 1) {
195
- vt_1 = new ITQTransform (d, d_out, true);
196
- d = d_out;
197
- } else if (stok == "ITQ") {
198
- vt_1 = new ITQTransform (d, d, false);
199
- } else if (sscanf (tok, "Pad%d", &d_out) == 1) {
200
- if (d_out > d) {
201
- vt_1 = new RemapDimensionsTransform (d, d_out, false);
202
- d = d_out;
203
- }
204
- } else if (stok == "L2norm") {
205
- vt_1 = new NormalizationTransform (d, 2.0);
206
-
207
- // coarse quantizers
208
- } else if (!coarse_quantizer &&
209
- sscanf (tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
210
- coarse_quantizer_1 = new IndexHNSWFlat (d, M, metric);
211
-
212
- } else if (!coarse_quantizer &&
213
- sscanf (tok, "IVF%" PRId64, &ncentroids) == 1) {
214
- if (!parenthesis_ivf.empty()) {
215
- coarse_quantizer_1 =
216
- index_factory(d, parenthesis_ivf.c_str(), metric);
217
-
218
- } else if (metric == METRIC_L2) {
219
- coarse_quantizer_1 = new IndexFlatL2 (d);
220
- } else {
221
- coarse_quantizer_1 = new IndexFlatIP (d);
222
- }
158
+ AdditiveQuantizer::Search_type_t aq_parse_search_type(
159
+ std::string stok,
160
+ MetricType metric) {
161
+ if (stok == "") {
162
+ return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
163
+ : AdditiveQuantizer::ST_LUT_nonorm;
164
+ }
165
+ int pos = stok.rfind("_");
166
+ return aq_search_type[stok.substr(pos)];
167
+ }
223
168
 
224
- } else if (!coarse_quantizer && sscanf (tok, "IMI2x%d", &nbit) == 1) {
225
- FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2,
226
- "MultiIndex not implemented for inner prod search");
227
- coarse_quantizer_1 = new MultiIndexQuantizer (d, 2, nbit);
228
- ncentroids = 1 << (2 * nbit);
229
-
230
- } else if (!coarse_quantizer &&
231
- sscanf (tok, "Residual%dx%d", &M, &nbit) == 2) {
232
- FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2,
233
- "MultiIndex not implemented for inner prod search");
234
- coarse_quantizer_1 = new MultiIndexQuantizer (d, M, nbit);
235
- ncentroids = int64_t(1) << (M * nbit);
236
- use_2layer = true;
237
-
238
- } else if (!coarse_quantizer &&
239
- sscanf (tok, "Residual%" PRId64, &ncentroids) == 1) {
240
- coarse_quantizer_1 = new IndexFlatL2 (d);
241
- use_2layer = true;
242
-
243
- } else if (stok == "IDMap") {
244
- add_idmap = true;
245
-
246
- // IVFs
247
- } else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
248
- if (coarse_quantizer) {
249
- // if there was an IVF in front, then it is an IVFFlat
250
- IndexIVF *index_ivf = stok == "Flat" ?
251
- new IndexIVFFlat (
252
- coarse_quantizer, d, ncentroids, metric) :
253
- new IndexIVFFlatDedup (
254
- coarse_quantizer, d, ncentroids, metric);
255
- index_ivf->quantizer_trains_alone =
256
- get_trains_alone (coarse_quantizer);
257
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
258
- del_coarse_quantizer.release ();
259
- index_ivf->own_fields = true;
260
- index_1 = index_ivf;
261
- } else if (hnsw_M > 0) {
262
- index_1 = new IndexHNSWFlat (d, hnsw_M, metric);
263
- } else {
264
- FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup",
265
- "dedup supported only for IVFFlat");
266
- index_1 = new IndexFlat (d, metric);
267
- }
268
- } else if (!index && (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
269
- stok == "SQfp16")) {
270
- ScalarQuantizer::QuantizerType qt =
271
- stok == "SQ8" ? ScalarQuantizer::QT_8bit :
272
- stok == "SQ6" ? ScalarQuantizer::QT_6bit :
273
- stok == "SQ4" ? ScalarQuantizer::QT_4bit :
274
- stok == "SQfp16" ? ScalarQuantizer::QT_fp16 :
275
- ScalarQuantizer::QT_4bit;
276
- if (coarse_quantizer) {
277
- FAISS_THROW_IF_NOT (!use_2layer);
278
- IndexIVFScalarQuantizer *index_ivf =
279
- new IndexIVFScalarQuantizer (
280
- coarse_quantizer, d, ncentroids, qt, metric);
281
- index_ivf->quantizer_trains_alone =
282
- get_trains_alone (coarse_quantizer);
283
- del_coarse_quantizer.release ();
284
- index_ivf->own_fields = true;
285
- index_1 = index_ivf;
286
- } else if (hnsw_M > 0) {
287
- index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
288
- } else {
289
- index_1 = new IndexScalarQuantizer (d, qt, metric);
290
- }
291
- } else if (!index && sscanf (tok, "PQ%d+%d", &M, &M2) == 2) {
292
- FAISS_THROW_IF_NOT_MSG(coarse_quantizer,
293
- "PQ with + works only with an IVF");
294
- FAISS_THROW_IF_NOT_MSG(metric == METRIC_L2,
295
- "IVFPQR not implemented for inner product search");
296
- IndexIVFPQR *index_ivf = new IndexIVFPQR (
297
- coarse_quantizer, d, ncentroids, M, 8, M2, 8);
298
- index_ivf->quantizer_trains_alone =
299
- get_trains_alone (coarse_quantizer);
300
- del_coarse_quantizer.release ();
301
- index_ivf->own_fields = true;
302
- index_1 = index_ivf;
303
- } else if (!index && (
304
- sscanf (tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
305
- (sscanf (tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
306
- (sscanf (tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
307
- if (bbs == -1) {
308
- bbs = 32;
309
- }
310
- bool by_residual = str_ends_with(stok, "fsr");
311
- if (coarse_quantizer) {
312
- IndexIVFPQFastScan *index_ivf = new IndexIVFPQFastScan(
313
- coarse_quantizer, d, ncentroids, M, 4, metric, bbs
314
- );
315
- index_ivf->quantizer_trains_alone =
316
- get_trains_alone (coarse_quantizer);
317
- index_ivf->metric_type = metric;
318
- index_ivf->by_residual = by_residual;
319
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
320
- del_coarse_quantizer.release ();
321
- index_ivf->own_fields = true;
322
- index_1 = index_ivf;
323
- } else {
324
- IndexPQFastScan *index_pq = new IndexPQFastScan (
325
- d, M, 4, metric, bbs
326
- );
327
- index_1 = index_pq;
328
- }
329
- } else if (!index && (sscanf (tok, "PQ%dx%d", &M, &nbit) == 2 ||
330
- sscanf (tok, "PQ%d", &M) == 1 ||
331
- sscanf (tok, "PQ%dnp", &M) == 1)) {
332
- bool do_polysemous_training = stok.find("np") == std::string::npos;
333
- if (coarse_quantizer) {
334
- if (!use_2layer) {
335
- IndexIVFPQ *index_ivf = new IndexIVFPQ (
336
- coarse_quantizer, d, ncentroids, M, nbit);
337
- index_ivf->quantizer_trains_alone =
338
- get_trains_alone (coarse_quantizer);
339
- index_ivf->metric_type = metric;
340
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
341
- del_coarse_quantizer.release ();
342
- index_ivf->own_fields = true;
343
- index_ivf->do_polysemous_training = do_polysemous_training;
344
- index_1 = index_ivf;
345
- } else {
346
- Index2Layer *index_2l = new Index2Layer
347
- (coarse_quantizer, ncentroids, M, nbit);
348
- index_2l->q1.quantizer_trains_alone =
349
- get_trains_alone (coarse_quantizer);
350
- index_2l->q1.own_fields = true;
351
- index_1 = index_2l;
352
- }
353
- } else if (hnsw_M > 0) {
354
- IndexHNSWPQ *ipq = new IndexHNSWPQ(d, M, hnsw_M);
355
- dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
356
- do_polysemous_training;
357
- index_1 = ipq;
358
- } else {
359
- IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric);
360
- index_pq->do_polysemous_training = do_polysemous_training;
361
- index_1 = index_pq;
362
- }
363
- } else if (!index &&
364
- sscanf (tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
365
- Index * quant = new IndexFlatL2 (d);
366
- IndexHNSW2Level * hidx2l = new IndexHNSW2Level (quant, ncent, pq_m, M);
367
- Index2Layer * idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
368
- idx2l->q1.own_fields = true;
369
- index_1 = hidx2l;
370
- } else if (!index &&
371
- sscanf (tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
372
- Index * quant = new MultiIndexQuantizer (d, 2, nbit);
373
- IndexHNSW2Level * hidx2l =
374
- new IndexHNSW2Level (quant, 1 << (2 * nbit), pq_m, M);
375
- Index2Layer * idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
376
- idx2l->q1.own_fields = true;
377
- idx2l->q1.quantizer_trains_alone = 1;
378
- index_1 = hidx2l;
379
- } else if (!index &&
380
- sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
381
- index_1 = new IndexHNSWPQ (d, pq_m, M);
382
- } else if (!index &&
383
- sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
384
- pq_m == 8) {
385
- index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M);
386
- } else if (!index &&
387
- sscanf (tok, "HNSW%d", &M) == 1) {
388
- hnsw_M = M;
389
- // here it is unclear what we want: HNSW flat or HNSWx,Y ?
390
- } else if (!index && (stok == "LSH" || stok == "LSHr" ||
391
- stok == "LSHrt" || stok == "LSHt")) {
392
- bool rotate_data = strstr(tok, "r") != nullptr;
393
- bool train_thresholds = strstr(tok, "t") != nullptr;
394
- index_1 = new IndexLSH (d, d, rotate_data, train_thresholds);
395
- } else if (!index &&
396
- sscanf (tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
397
- FAISS_THROW_IF_NOT(!coarse_quantizer);
398
- index_1 = new IndexLattice(d, M, nbit, r2);
399
- } else if (stok == "RFlat") {
400
- parenthesis_refine = "Flat";
401
- } else if (stok == "Refine") {
402
- FAISS_THROW_IF_NOT_MSG(
403
- !parenthesis_refine.empty(),
404
- "Refine index should be provided in parentheses"
405
- );
169
+ std::vector<size_t> aq_parse_nbits(std::string stok) {
170
+ std::vector<size_t> nbits;
171
+ std::smatch sm;
172
+ while (std::regex_search(stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
173
+ int M = std::stoi(sm[1].str());
174
+ int nbit = std::stoi(sm[2].str());
175
+ nbits.resize(nbits.size() + M, nbit);
176
+ stok = sm.suffix();
177
+ }
178
+ return nbits;
179
+ }
180
+
181
+ /***************************************************************
182
+ * Parse VectorTransform
183
+ */
184
+
185
+ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
186
+ std::smatch sm;
187
+ auto match = [&sm, description](std::string pattern) {
188
+ return re_match(description, pattern, sm);
189
+ };
190
+ if (match("PCA(W?)(R?)([0-9]+)")) {
191
+ bool white = sm[1].length() > 0;
192
+ bool rot = sm[2].length() > 0;
193
+ return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
194
+ }
195
+ if (match("L2[nN]orm")) {
196
+ return new NormalizationTransform(d, 2.0);
197
+ }
198
+ if (match("RR([0-9]+)?")) {
199
+ return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
200
+ }
201
+ if (match("ITQ([0-9]+)?")) {
202
+ return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
203
+ }
204
+ if (match("OPQ([0-9]+)(_[0-9]+)?")) {
205
+ int M = std::stoi(sm[1].str());
206
+ int d_out = mres_to_int(sm[2], d, 1);
207
+ return new OPQMatrix(d, M, d_out);
208
+ }
209
+ if (match("Pad([0-9]+)")) {
210
+ int d_out = std::stoi(sm[1].str());
211
+ return new RemapDimensionsTransform(d, std::max(d_out, d), false);
212
+ }
213
+ return nullptr;
214
+ };
215
+
216
+ /***************************************************************
217
+ * Parse IndexIVF
218
+ */
219
+
220
+ // parsing guard + function
221
+ Index* parse_coarse_quantizer(
222
+ const std::string& description,
223
+ int d,
224
+ MetricType mt,
225
+ std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
226
+ size_t& nlist,
227
+ bool& use_2layer) {
228
+ std::smatch sm;
229
+ auto match = [&sm, description](std::string pattern) {
230
+ return re_match(description, pattern, sm);
231
+ };
232
+ use_2layer = false;
233
+
234
+ if (match("IVF([0-9]+)")) {
235
+ nlist = std::stoi(sm[1].str());
236
+ return new IndexFlat(d, mt);
237
+ }
238
+ if (match("IMI2x([0-9]+)")) {
239
+ int nbit = std::stoi(sm[1].str());
240
+ FAISS_THROW_IF_NOT_MSG(
241
+ mt == METRIC_L2,
242
+ "MultiIndex not implemented for inner prod search");
243
+ nlist = (size_t)1 << (2 * nbit);
244
+ return new MultiIndexQuantizer(d, 2, nbit);
245
+ }
246
+ if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
247
+ nlist = std::stoi(sm[1].str());
248
+ int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
249
+ return new IndexHNSWFlat(d, hnsw_M, mt);
250
+ }
251
+ if (match("IVF([0-9]+)_NSG([0-9]+)")) {
252
+ nlist = std::stoi(sm[1].str());
253
+ int R = std::stoi(sm[2]);
254
+ return new IndexNSGFlat(d, R, mt);
255
+ }
256
+ if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
257
+ nlist = std::stoi(sm[1].str());
258
+ int no = std::stoi(sm[2].str());
259
+ FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
260
+ return parenthesis_indexes[no].release();
261
+ }
262
+
263
+ // these two generate Index2Layer's not IndexIVF's
264
+ if (match("Residual([0-9]+)x([0-9]+)")) {
265
+ FAISS_THROW_IF_NOT_MSG(
266
+ mt == METRIC_L2,
267
+ "MultiIndex not implemented for inner prod search");
268
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
269
+ nlist = (size_t)1 << (M * nbit);
270
+ use_2layer = true;
271
+ return new MultiIndexQuantizer(d, M, nbit);
272
+ }
273
+ if (match("Residual([0-9]+)")) {
274
+ FAISS_THROW_IF_NOT_MSG(
275
+ mt == METRIC_L2,
276
+ "Residual not implemented for inner prod search");
277
+ use_2layer = true;
278
+ nlist = mres_to_int(sm[1]);
279
+ return new IndexFlatL2(d);
280
+ }
281
+ return nullptr;
282
+ }
283
+
284
+ // parse the code description part of the IVF description
285
+
286
+ IndexIVF* parse_IndexIVF(
287
+ const std::string& code_string,
288
+ std::unique_ptr<Index>& quantizer,
289
+ size_t nlist,
290
+ MetricType mt) {
291
+ std::smatch sm;
292
+ auto match = [&sm, &code_string](const std::string pattern) {
293
+ return re_match(code_string, pattern, sm);
294
+ };
295
+ auto get_q = [&quantizer] { return quantizer.release(); };
296
+ int d = quantizer->d;
297
+
298
+ if (match("Flat")) {
299
+ return new IndexIVFFlat(get_q(), d, nlist, mt);
300
+ }
301
+ if (match("FlatDedup")) {
302
+ return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
303
+ }
304
+ if (match(sq_pattern)) {
305
+ return new IndexIVFScalarQuantizer(
306
+ get_q(), d, nlist, sq_types[sm[1].str()], mt);
307
+ }
308
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
309
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
310
+ IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
311
+ index_ivf->do_polysemous_training = sm[3].str() != "np";
312
+ return index_ivf;
313
+ }
314
+ if (match("PQ([0-9]+)\\+([0-9]+)")) {
315
+ FAISS_THROW_IF_NOT_MSG(
316
+ mt == METRIC_L2,
317
+ "IVFPQR not implemented for inner product search");
318
+ int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
319
+ return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
320
+ }
321
+ if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
322
+ int M = mres_to_int(sm[1]);
323
+ int bbs = mres_to_int(sm[3], 32, 1);
324
+ IndexIVFPQFastScan* index_ivf =
325
+ new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
326
+ index_ivf->by_residual = sm[2].str() == "r";
327
+ return index_ivf;
328
+ }
329
+ if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
330
+ std::vector<size_t> nbits = aq_parse_nbits(sm.str());
331
+ AdditiveQuantizer::Search_type_t st =
332
+ aq_parse_search_type(sm[sm.size() - 1].str(), mt);
333
+ IndexIVF* index_ivf;
334
+ if (sm[1].str() == "RQ") {
335
+ index_ivf = new IndexIVFResidualQuantizer(
336
+ get_q(), d, nlist, nbits, mt, st);
406
337
  } else {
407
- FAISS_THROW_FMT( "could not parse token \"%s\" in %s\n",
408
- tok, description_in);
338
+ FAISS_THROW_IF_NOT(nbits.size() > 0);
339
+ index_ivf = new IndexIVFLocalSearchQuantizer(
340
+ get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
341
+ }
342
+ return index_ivf;
343
+ }
344
+ if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
345
+ int outdim = mres_to_int(sm[2], d); // is also the number of bits
346
+ std::unique_ptr<VectorTransform> vt;
347
+ if (sm[1] == "ITQ") {
348
+ vt.reset(new ITQTransform(d, outdim, d != outdim));
349
+ } else if (sm[1] == "PCA") {
350
+ vt.reset(new PCAMatrix(d, outdim));
351
+ } else if (sm[1] == "PCAR") {
352
+ vt.reset(new PCAMatrix(d, outdim, 0, true));
353
+ }
354
+ // the rationale for -1e10 is that this corresponds to simple
355
+ // thresholding
356
+ float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
357
+ IndexIVFSpectralHash* index_ivf =
358
+ new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
359
+ index_ivf->replace_vt(vt.release(), true);
360
+ if (sm[4].length()) {
361
+ std::string s = sm[4].str();
362
+ index_ivf->threshold_type = s == "g"
363
+ ? IndexIVFSpectralHash::Thresh_global
364
+ : s == "c"
365
+ ? IndexIVFSpectralHash::Thresh_centroid
366
+ :
367
+ /* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
409
368
  }
369
+ return index_ivf;
370
+ }
371
+ return nullptr;
372
+ }
410
373
 
411
- if (index_1 && add_idmap) {
412
- IndexIDMap *idmap = new IndexIDMap(index_1);
413
- del_index.set (idmap);
414
- idmap->own_fields = true;
415
- index_1 = idmap;
416
- add_idmap = false;
374
+ /***************************************************************
375
+ * Parse IndexHNSW
376
+ */
377
+
378
+ IndexHNSW* parse_IndexHNSW(
379
+ const std::string code_string,
380
+ int d,
381
+ MetricType mt,
382
+ int hnsw_M) {
383
+ std::smatch sm;
384
+ auto match = [&sm, &code_string](const std::string& pattern) {
385
+ return re_match(code_string, pattern, sm);
386
+ };
387
+
388
+ if (match("Flat|")) {
389
+ return new IndexHNSWFlat(d, hnsw_M, mt);
390
+ }
391
+ if (match("PQ([0-9]+)(np)?")) {
392
+ int M = std::stoi(sm[1].str());
393
+ IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
394
+ dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
395
+ sm[2].str() != "np";
396
+ return ipq;
397
+ }
398
+ if (match(sq_pattern)) {
399
+ return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
400
+ }
401
+ if (match("([0-9]+)\\+PQ([0-9]+)?")) {
402
+ int ncent = mres_to_int(sm[1]);
403
+ int pq_m = mres_to_int(sm[2]);
404
+ IndexHNSW2Level* hidx2l =
405
+ new IndexHNSW2Level(new IndexFlatL2(d), ncent, pq_m, hnsw_M);
406
+ dynamic_cast<Index2Layer*>(hidx2l->storage)->q1.own_fields = true;
407
+ return hidx2l;
408
+ }
409
+ if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
410
+ int nbit = mres_to_int(sm[1]);
411
+ int pq_m = mres_to_int(sm[2]);
412
+ Index* quant = new MultiIndexQuantizer(d, 2, nbit);
413
+ IndexHNSW2Level* hidx2l = new IndexHNSW2Level(
414
+ quant, (size_t)1 << (2 * nbit), pq_m, hnsw_M);
415
+ Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
416
+ idx2l->q1.own_fields = true;
417
+ idx2l->q1.quantizer_trains_alone = 1;
418
+ return hidx2l;
419
+ }
420
+
421
+ return nullptr;
422
+ }
423
+
424
+ /***************************************************************
425
+ * Parse basic indexes
426
+ */
427
+
428
+ Index* parse_other_indexes(
429
+ const std::string& description,
430
+ int d,
431
+ MetricType metric) {
432
+ std::smatch sm;
433
+ auto match = [&sm, description](const std::string& pattern) {
434
+ return re_match(description, pattern, sm);
435
+ };
436
+
437
+ // IndexFlat
438
+ if (description == "Flat") {
439
+ return new IndexFlat(d, metric);
440
+ }
441
+
442
+ // IndexLSH
443
+ if (match("LSH(r?)(t?)")) {
444
+ bool rotate_data = sm[1].length() > 0;
445
+ bool train_thresholds = sm[2].length() > 0;
446
+ FAISS_THROW_IF_NOT(metric == METRIC_L2);
447
+ return new IndexLSH(d, d, rotate_data, train_thresholds);
448
+ }
449
+
450
+ // IndexLattice
451
+ if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
452
+ int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
453
+ int nbit = std::stoi(sm[3].str());
454
+ return new IndexLattice(d, M, nbit, r2);
455
+ }
456
+
457
+ // IndexNSGFlat
458
+ if (match("NSG([0-9]+)(,Flat)?")) {
459
+ return new IndexNSGFlat(d, std::stoi(sm[1].str()), metric);
460
+ }
461
+
462
+ // IndexScalarQuantizer
463
+ if (match(sq_pattern)) {
464
+ return new IndexScalarQuantizer(d, sq_types[description], metric);
465
+ }
466
+
467
+ // IndexPQ
468
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
469
+ int M = std::stoi(sm[1].str());
470
+ int nbit = mres_to_int(sm[2], 8, 1);
471
+ IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
472
+ index_pq->do_polysemous_training = sm[3].str() != "np";
473
+ return index_pq;
474
+ }
475
+
476
+ // IndexPQFastScan
477
+ if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
478
+ int M = std::stoi(sm[1].str());
479
+ int bbs = mres_to_int(sm[2], 32, 1);
480
+ return new IndexPQFastScan(d, M, 4, metric, bbs);
481
+ }
482
+
483
+ // IndexResidualCoarseQuantizer and IndexResidualQuantizer
484
+ std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
485
+ if (match(pattern)) {
486
+ std::vector<size_t> nbits = aq_parse_nbits(description);
487
+ if (sm[1].str() == "RCQ") {
488
+ return new ResidualCoarseQuantizer(d, nbits, metric);
417
489
  }
490
+ AdditiveQuantizer::Search_type_t st =
491
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
492
+ return new IndexResidualQuantizer(d, nbits, metric, st);
493
+ }
418
494
 
419
- if (vt_1) {
420
- vts.chain.push_back (vt_1);
495
+ // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
496
+ if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
497
+ std::vector<size_t> nbits = aq_parse_nbits(description);
498
+ int M = mres_to_int(sm[2]);
499
+ int nbit = mres_to_int(sm[3]);
500
+ if (sm[1].str() == "LSCQ") {
501
+ return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
421
502
  }
503
+ AdditiveQuantizer::Search_type_t st =
504
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
505
+ return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
506
+ }
507
+
508
+ return nullptr;
509
+ }
510
+
511
+ /***************************************************************
512
+ * Driver function
513
+ */
514
+ std::unique_ptr<Index> index_factory_sub(
515
+ int d,
516
+ std::string description,
517
+ MetricType metric) {
518
+ // handle composite indexes
519
+
520
+ bool verbose = index_factory_verbose;
422
521
 
423
- if (coarse_quantizer_1) {
424
- coarse_quantizer = coarse_quantizer_1;
425
- del_coarse_quantizer.set (coarse_quantizer);
522
+ if (verbose) {
523
+ printf("begin parse VectorTransforms: %s \n", description.c_str());
524
+ }
525
+
526
+ // for the current match
527
+ std::smatch sm;
528
+
529
+ // handle refines
530
+ if (re_match(description, "(.+),RFlat", sm) ||
531
+ re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
532
+ std::unique_ptr<Index> filter_index =
533
+ index_factory_sub(d, sm[1].str(), metric);
534
+ std::unique_ptr<Index> refine_index;
535
+
536
+ if (sm.size() == 3) { // Refine
537
+ refine_index = index_factory_sub(d, sm[2].str(), metric);
538
+ } else { // RFlat
539
+ refine_index.reset(new IndexFlat(d, metric));
426
540
  }
541
+ IndexRefine* index_rf =
542
+ new IndexRefine(filter_index.get(), refine_index.get());
543
+ index_rf->own_fields = true;
544
+ filter_index.release();
545
+ refine_index.release();
546
+ index_rf->own_refine_index = true;
547
+ return std::unique_ptr<Index>(index_rf);
548
+ }
427
549
 
428
- if (index_1) {
429
- index = index_1;
430
- del_index.set (index);
550
+ // IndexPreTransform
551
+ // should handle this first (even before parentheses) because it changes d
552
+ std::vector<std::unique_ptr<VectorTransform>> vts;
553
+ VectorTransform* vt = nullptr;
554
+ while (re_match(description, "([^,]+),(.*)", sm) &&
555
+ (vt = parse_VectorTransform(sm[1], d))) {
556
+ // reset loop
557
+ description = sm[sm.size() - 1];
558
+ vts.emplace_back(vt);
559
+ d = vts.back()->d_out;
560
+ }
561
+
562
+ if (vts.size() > 0) {
563
+ std::unique_ptr<Index> sub_index =
564
+ index_factory_sub(d, description, metric);
565
+ IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
566
+ std::unique_ptr<Index> ret(index_pt);
567
+ index_pt->own_fields = true;
568
+ sub_index.release();
569
+ while (vts.size() > 0) {
570
+ if (verbose) {
571
+ printf("prepend trans %d -> %d\n",
572
+ vts.back()->d_in,
573
+ vts.back()->d_out);
574
+ }
575
+ index_pt->prepend_transform(vts.back().release());
576
+ vts.pop_back();
431
577
  }
578
+ return ret;
432
579
  }
433
580
 
434
- if (!index && hnsw_M > 0) {
435
- index = new IndexHNSWFlat (d, hnsw_M, metric);
436
- del_index.set (index);
581
+ // what we got from the parentheses
582
+ std::vector<std::unique_ptr<Index>> parenthesis_indexes;
583
+
584
+ int begin = 0;
585
+ while (description.find('(', begin) != std::string::npos) {
586
+ // replace indexes in () with Index0, Index1, etc.
587
+ int i0, i1;
588
+ find_matching_parentheses(description, i0, i1, begin);
589
+ std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
590
+ int no = parenthesis_indexes.size();
591
+ parenthesis_indexes.push_back(
592
+ index_factory_sub(d, sub_description, metric));
593
+ description = description.substr(0, i0 + 1) + "Index" +
594
+ std::to_string(no) + description.substr(i1);
595
+ begin = i1 + 1;
437
596
  }
438
597
 
439
- FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
440
- description_in);
598
+ if (verbose) {
599
+ printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
600
+ description.c_str(),
601
+ parenthesis_indexes.size(),
602
+ d);
603
+ }
441
604
 
442
- // nothing can go wrong now
443
- del_index.release ();
444
- del_coarse_quantizer.release ();
605
+ // IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
606
+ // support both
607
+ if (re_match(description, "(.+),IDMap", sm) ||
608
+ re_match(description, "IDMap,(.+)", sm)) {
609
+ IndexIDMap* idmap = new IndexIDMap(
610
+ index_factory_sub(d, sm[1].str(), metric).release());
611
+ idmap->own_fields = true;
612
+ return std::unique_ptr<Index>(idmap);
613
+ }
445
614
 
446
- if (add_idmap) {
447
- fprintf(stderr, "index_factory: WARNING: "
448
- "IDMap option not used\n");
615
+ { // handle basic index types
616
+ Index* index = parse_other_indexes(description, d, metric);
617
+ if (index) {
618
+ return std::unique_ptr<Index>(index);
619
+ }
449
620
  }
450
621
 
451
- if (vts.chain.size() > 0) {
452
- IndexPreTransform *index_pt = new IndexPreTransform (index);
453
- index_pt->own_fields = true;
454
- // add from back
455
- while (vts.chain.size() > 0) {
456
- index_pt->prepend_transform (vts.chain.back ());
457
- vts.chain.pop_back ();
622
+ // HNSW variants (it was unclear in the old version that the separator was a
623
+ // "," so we support both "_" and ",")
624
+ if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
625
+ int hnsw_M = mres_to_int(sm[1], 32);
626
+ // We also accept empty code string (synonym of Flat)
627
+ std::string code_string =
628
+ sm[2].length() > 0 ? sm[2].str().substr(1) : "";
629
+ if (verbose) {
630
+ printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
631
+ description.c_str(),
632
+ code_string.c_str(),
633
+ hnsw_M);
458
634
  }
459
- index = index_pt;
635
+
636
+ IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
637
+ FAISS_THROW_IF_NOT_FMT(
638
+ index,
639
+ "could not parse HNSW code description %s in %s",
640
+ code_string.c_str(),
641
+ description.c_str());
642
+ return std::unique_ptr<Index>(index);
460
643
  }
461
644
 
462
- if (!parenthesis_refine.empty()) {
463
- Index *refine_index = index_factory(d_in, parenthesis_refine.c_str(), metric);
464
- IndexRefine *index_rf = new IndexRefine(index, refine_index);
465
- index_rf->own_refine_index = true;
466
- index_rf->own_fields = true;
467
- index = index_rf;
645
+ // IndexIVF
646
+ {
647
+ size_t nlist;
648
+ bool use_2layer;
649
+ size_t comma = description.find(",");
650
+ std::string coarse_string = description.substr(0, comma);
651
+ // Match coarse quantizer part first
652
+ std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
653
+ description.substr(0, comma),
654
+ d,
655
+ metric,
656
+ parenthesis_indexes,
657
+ nlist,
658
+ use_2layer));
659
+
660
+ if (comma != std::string::npos && quantizer.get()) {
661
+ std::string code_description = description.substr(comma + 1);
662
+ if (use_2layer) {
663
+ bool ok =
664
+ re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
665
+ FAISS_THROW_IF_NOT_FMT(
666
+ ok,
667
+ "could not parse 2 layer code description %s in %s",
668
+ code_description.c_str(),
669
+ description.c_str());
670
+ int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
671
+ Index2Layer* index_2l =
672
+ new Index2Layer(quantizer.release(), nlist, M, nbit);
673
+ index_2l->q1.own_fields = true;
674
+ index_2l->q1.quantizer_trains_alone =
675
+ get_trains_alone(index_2l->q1.quantizer);
676
+ return std::unique_ptr<Index>(index_2l);
677
+ }
678
+
679
+ IndexIVF* index_ivf =
680
+ parse_IndexIVF(code_description, quantizer, nlist, metric);
681
+
682
+ FAISS_THROW_IF_NOT_FMT(
683
+ index_ivf,
684
+ "could not parse code description %s in %s",
685
+ code_description.c_str(),
686
+ description.c_str());
687
+ return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
688
+ }
468
689
  }
690
+ FAISS_THROW_FMT("could not parse index string %s", description.c_str());
691
+ return nullptr;
692
+ }
469
693
 
470
- return index;
694
+ } // anonymous namespace
695
+
696
+ Index* index_factory(int d, const char* description, MetricType metric) {
697
+ return index_factory_sub(d, description, metric).release();
471
698
  }
472
699
 
473
- IndexBinary *index_binary_factory(int d, const char *description)
474
- {
475
- IndexBinary *index = nullptr;
700
+ IndexBinary* index_binary_factory(int d, const char* description) {
701
+ IndexBinary* index = nullptr;
476
702
 
477
703
  int ncentroids = -1;
478
704
  int M, nhash, b;
479
705
 
480
706
  if (sscanf(description, "BIVF%d_HNSW%d", &ncentroids, &M) == 2) {
481
- IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
482
- new IndexBinaryHNSW(d, M), d, ncentroids
483
- );
707
+ IndexBinaryIVF* index_ivf =
708
+ new IndexBinaryIVF(new IndexBinaryHNSW(d, M), d, ncentroids);
484
709
  index_ivf->own_fields = true;
485
710
  index = index_ivf;
486
711
 
487
712
  } else if (sscanf(description, "BIVF%d", &ncentroids) == 1) {
488
- IndexBinaryIVF *index_ivf = new IndexBinaryIVF(
489
- new IndexBinaryFlat(d), d, ncentroids
490
- );
713
+ IndexBinaryIVF* index_ivf =
714
+ new IndexBinaryIVF(new IndexBinaryFlat(d), d, ncentroids);
491
715
  index_ivf->own_fields = true;
492
716
  index = index_ivf;
493
717
 
494
718
  } else if (sscanf(description, "BHNSW%d", &M) == 1) {
495
- IndexBinaryHNSW *index_hnsw = new IndexBinaryHNSW(d, M);
719
+ IndexBinaryHNSW* index_hnsw = new IndexBinaryHNSW(d, M);
496
720
  index = index_hnsw;
497
721
 
498
722
  } else if (sscanf(description, "BHash%dx%d", &nhash, &b) == 2) {
499
- index = new IndexBinaryMultiHash (d, nhash, b);
723
+ index = new IndexBinaryMultiHash(d, nhash, b);
500
724
 
501
725
  } else if (sscanf(description, "BHash%d", &b) == 1) {
502
- index = new IndexBinaryHash (d, b);
726
+ index = new IndexBinaryHash(d, b);
503
727
 
504
728
  } else if (std::string(description) == "BFlat") {
505
729
  index = new IndexBinaryFlat(d);
506
730
 
507
731
  } else {
508
- FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index",
509
- description);
732
+ FAISS_THROW_IF_NOT_FMT(
733
+ index, "description %s did not generate an index", description);
510
734
  }
511
735
 
512
736
  return index;
513
737
  }
514
738
 
515
-
516
-
517
739
  } // namespace faiss