faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -7,43 +7,48 @@
7
7
 
8
8
  // -*- c++ -*-
9
9
 
10
-
11
10
  #include <faiss/IndexIVFSpectralHash.h>
12
11
 
13
- #include <memory>
14
- #include <algorithm>
15
12
  #include <stdint.h>
13
+ #include <algorithm>
14
+ #include <memory>
16
15
 
16
+ #include <faiss/IndexLSH.h>
17
+ #include <faiss/IndexPreTransform.h>
18
+ #include <faiss/VectorTransform.h>
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
17
21
  #include <faiss/utils/hamming.h>
18
22
  #include <faiss/utils/utils.h>
19
- #include <faiss/impl/FaissAssert.h>
20
- #include <faiss/impl/AuxIndexStructures.h>
21
- #include <faiss/VectorTransform.h>
22
23
 
23
24
  namespace faiss {
24
25
 
25
-
26
- IndexIVFSpectralHash::IndexIVFSpectralHash (
27
- Index * quantizer, size_t d, size_t nlist,
28
- int nbit, float period):
29
- IndexIVF (quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2),
30
- nbit (nbit), period (period), threshold_type (Thresh_global)
31
- {
32
- FAISS_THROW_IF_NOT (code_size % 4 == 0);
33
- RandomRotationMatrix *rr = new RandomRotationMatrix (d, nbit);
34
- rr->init (1234);
26
+ IndexIVFSpectralHash::IndexIVFSpectralHash(
27
+ Index* quantizer,
28
+ size_t d,
29
+ size_t nlist,
30
+ int nbit,
31
+ float period)
32
+ : IndexIVF(quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2),
33
+ nbit(nbit),
34
+ period(period),
35
+ threshold_type(Thresh_global) {
36
+ RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
37
+ rr->init(1234);
35
38
  vt = rr;
36
39
  own_fields = true;
37
40
  is_trained = false;
38
41
  }
39
42
 
40
- IndexIVFSpectralHash::IndexIVFSpectralHash():
41
- IndexIVF(), vt(nullptr), own_fields(false),
42
- nbit(0), period(0), threshold_type(Thresh_global)
43
- {}
43
+ IndexIVFSpectralHash::IndexIVFSpectralHash()
44
+ : IndexIVF(),
45
+ vt(nullptr),
46
+ own_fields(false),
47
+ nbit(0),
48
+ period(0),
49
+ threshold_type(Thresh_global) {}
44
50
 
45
- IndexIVFSpectralHash::~IndexIVFSpectralHash ()
46
- {
51
+ IndexIVFSpectralHash::~IndexIVFSpectralHash() {
47
52
  if (own_fields) {
48
53
  delete vt;
49
54
  }
@@ -51,35 +56,33 @@ IndexIVFSpectralHash::~IndexIVFSpectralHash ()
51
56
 
52
57
  namespace {
53
58
 
54
-
55
- float median (size_t n, float *x) {
59
+ float median(size_t n, float* x) {
56
60
  std::sort(x, x + n);
57
61
  if (n % 2 == 1) {
58
- return x [n / 2];
62
+ return x[n / 2];
59
63
  } else {
60
- return (x [n / 2 - 1] + x [n / 2]) / 2;
64
+ return (x[n / 2 - 1] + x[n / 2]) / 2;
61
65
  }
62
66
  }
63
67
 
64
- }
65
-
68
+ } // namespace
66
69
 
67
- void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
68
- {
70
+ void IndexIVFSpectralHash::train_residual(idx_t n, const float* x) {
69
71
  if (!vt->is_trained) {
70
- vt->train (n, x);
72
+ vt->train(n, x);
71
73
  }
72
74
 
73
75
  if (threshold_type == Thresh_global) {
74
76
  // nothing to do
75
77
  return;
76
- } else if (threshold_type == Thresh_centroid ||
77
- threshold_type == Thresh_centroid_half) {
78
+ } else if (
79
+ threshold_type == Thresh_centroid ||
80
+ threshold_type == Thresh_centroid_half) {
78
81
  // convert all centroids with vt
79
- std::vector<float> centroids (nlist * d);
80
- quantizer->reconstruct_n (0, nlist, centroids.data());
82
+ std::vector<float> centroids(nlist * d);
83
+ quantizer->reconstruct_n(0, nlist, centroids.data());
81
84
  trained.resize(nlist * nbit);
82
- vt->apply_noalloc (nlist, centroids.data(), trained.data());
85
+ vt->apply_noalloc(nlist, centroids.data(), trained.data());
83
86
  if (threshold_type == Thresh_centroid_half) {
84
87
  for (size_t i = 0; i < nlist * nbit; i++) {
85
88
  trained[i] -= 0.25 * period;
@@ -90,12 +93,12 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
90
93
  // otherwise train medians
91
94
 
92
95
  // assign
93
- std::unique_ptr<idx_t []> idx (new idx_t [n]);
94
- quantizer->assign (n, x, idx.get());
96
+ std::unique_ptr<idx_t[]> idx(new idx_t[n]);
97
+ quantizer->assign(n, x, idx.get());
95
98
 
96
99
  std::vector<size_t> sizes(nlist + 1);
97
100
  for (size_t i = 0; i < n; i++) {
98
- FAISS_THROW_IF_NOT (idx[i] >= 0);
101
+ FAISS_THROW_IF_NOT(idx[i] >= 0);
99
102
  sizes[idx[i]]++;
100
103
  }
101
104
 
@@ -107,10 +110,10 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
107
110
  }
108
111
 
109
112
  // transform
110
- std::unique_ptr<float []> xt (vt->apply (n, x));
113
+ std::unique_ptr<float[]> xt(vt->apply(n, x));
111
114
 
112
115
  // transpose + reorder
113
- std::unique_ptr<float []> xo (new float[n * nbit]);
116
+ std::unique_ptr<float[]> xo(new float[n * nbit]);
114
117
 
115
118
  for (size_t i = 0; i < n; i++) {
116
119
  size_t idest = sizes[idx[i]]++;
@@ -119,14 +122,14 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
119
122
  }
120
123
  }
121
124
 
122
- trained.resize (n * nbit);
125
+ trained.resize(n * nbit);
123
126
  // compute medians
124
127
  #pragma omp for
125
128
  for (int i = 0; i < nlist; i++) {
126
129
  size_t i0 = i == 0 ? 0 : sizes[i - 1];
127
130
  size_t i1 = sizes[i];
128
131
  for (int j = 0; j < nbit; j++) {
129
- float *xoi = xo.get() + i0 + n * j;
132
+ float* xoi = xo.get() + i0 + n * j;
130
133
  if (i0 == i1) { // nothing to train
131
134
  trained[i * nbit + j] = 0.0;
132
135
  } else if (i1 == i0 + 1) {
@@ -138,75 +141,71 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
138
141
  }
139
142
  }
140
143
 
141
-
142
144
  namespace {
143
145
 
144
- void binarize_with_freq(size_t nbit, float freq,
145
- const float *x, const float *c,
146
- uint8_t *codes)
147
- {
148
- memset (codes, 0, (nbit + 7) / 8);
146
+ void binarize_with_freq(
147
+ size_t nbit,
148
+ float freq,
149
+ const float* x,
150
+ const float* c,
151
+ uint8_t* codes) {
152
+ memset(codes, 0, (nbit + 7) / 8);
149
153
  for (size_t i = 0; i < nbit; i++) {
150
154
  float xf = (x[i] - c[i]);
151
- int xi = int(floor(xf * freq));
152
- int bit = xi & 1;
155
+ int64_t xi = int64_t(floor(xf * freq));
156
+ int64_t bit = xi & 1;
153
157
  codes[i >> 3] |= bit << (i & 7);
154
158
  }
155
159
  }
156
160
 
161
+ }; // namespace
157
162
 
158
- };
159
-
160
-
161
-
162
- void IndexIVFSpectralHash::encode_vectors(idx_t n, const float* x_in,
163
- const idx_t *list_nos,
164
- uint8_t * codes,
165
- bool include_listnos) const
166
- {
167
- FAISS_THROW_IF_NOT (is_trained);
163
+ void IndexIVFSpectralHash::encode_vectors(
164
+ idx_t n,
165
+ const float* x_in,
166
+ const idx_t* list_nos,
167
+ uint8_t* codes,
168
+ bool include_listnos) const {
169
+ FAISS_THROW_IF_NOT(is_trained);
168
170
  float freq = 2.0 / period;
169
-
170
- FAISS_THROW_IF_NOT_MSG (!include_listnos, "listnos encoding not supported");
171
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
171
172
 
172
173
  // transform with vt
173
- std::unique_ptr<float []> x (vt->apply (n, x_in));
174
+ std::unique_ptr<float[]> x(vt->apply(n, x_in));
174
175
 
175
- #pragma omp parallel
176
- {
177
- std::vector<float> zero (nbit);
176
+ std::vector<float> zero(nbit);
178
177
 
179
- // each thread takes care of a subset of lists
180
178
  #pragma omp for
181
- for (idx_t i = 0; i < n; i++) {
182
- int64_t list_no = list_nos [i];
183
-
184
- if (list_no >= 0) {
185
- const float *c;
186
- if (threshold_type == Thresh_global) {
187
- c = zero.data();
188
- } else {
189
- c = trained.data() + list_no * nbit;
190
- }
191
- binarize_with_freq (nbit, freq,
192
- x.get() + i * nbit, c,
193
- codes + i * code_size) ;
179
+ for (idx_t i = 0; i < n; i++) {
180
+ int64_t list_no = list_nos[i];
181
+ uint8_t* code = codes + i * (code_size + coarse_size);
182
+
183
+ if (list_no >= 0) {
184
+ if (coarse_size) {
185
+ encode_listno(list_no, code);
186
+ }
187
+ const float* c;
188
+
189
+ if (threshold_type == Thresh_global) {
190
+ c = zero.data();
191
+ } else {
192
+ c = trained.data() + list_no * nbit;
194
193
  }
194
+ binarize_with_freq(
195
+ nbit, freq, x.get() + i * nbit, c, code + coarse_size);
196
+ } else {
197
+ memset(code, 0, code_size + coarse_size);
195
198
  }
196
199
  }
197
200
  }
198
201
 
199
202
  namespace {
200
203
 
201
-
202
- template<class HammingComputer>
203
- struct IVFScanner: InvertedListScanner {
204
-
204
+ template <class HammingComputer>
205
+ struct IVFScanner : InvertedListScanner {
205
206
  // copied from index structure
206
- const IndexIVFSpectralHash *index;
207
- size_t code_size;
207
+ const IndexIVFSpectralHash* index;
208
208
  size_t nbit;
209
- bool store_pairs;
210
209
 
211
210
  float period, freq;
212
211
  std::vector<float> q;
@@ -216,61 +215,57 @@ struct IVFScanner: InvertedListScanner {
216
215
 
217
216
  using idx_t = Index::idx_t;
218
217
 
219
- IVFScanner (const IndexIVFSpectralHash * index,
220
- bool store_pairs):
221
- index (index),
222
- code_size(index->code_size),
223
- nbit(index->nbit),
224
- store_pairs(store_pairs),
225
- period(index->period), freq(2.0 / index->period),
226
- q(nbit), zero(nbit), qcode(code_size),
227
- hc(qcode.data(), code_size)
228
- {
218
+ IVFScanner(const IndexIVFSpectralHash* index, bool store_pairs)
219
+ : index(index),
220
+ nbit(index->nbit),
221
+ period(index->period),
222
+ freq(2.0 / index->period),
223
+ q(nbit),
224
+ zero(nbit),
225
+ qcode(index->code_size),
226
+ hc(qcode.data(), index->code_size) {
227
+ this->store_pairs = store_pairs;
228
+ this->code_size = index->code_size;
229
229
  }
230
230
 
231
-
232
- void set_query (const float *query) override {
231
+ void set_query(const float* query) override {
233
232
  FAISS_THROW_IF_NOT(query);
234
233
  FAISS_THROW_IF_NOT(q.size() == nbit);
235
- index->vt->apply_noalloc (1, query, q.data());
234
+ index->vt->apply_noalloc(1, query, q.data());
236
235
 
237
- if (index->threshold_type ==
238
- IndexIVFSpectralHash::Thresh_global) {
239
- binarize_with_freq
240
- (nbit, freq, q.data(), zero.data(), qcode.data());
241
- hc.set (qcode.data(), code_size);
236
+ if (index->threshold_type == IndexIVFSpectralHash::Thresh_global) {
237
+ binarize_with_freq(nbit, freq, q.data(), zero.data(), qcode.data());
238
+ hc.set(qcode.data(), code_size);
242
239
  }
243
240
  }
244
241
 
245
- idx_t list_no;
246
-
247
- void set_list (idx_t list_no, float /*coarse_dis*/) override {
242
+ void set_list(idx_t list_no, float /*coarse_dis*/) override {
248
243
  this->list_no = list_no;
249
244
  if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
250
- const float *c = index->trained.data() + list_no * nbit;
251
- binarize_with_freq (nbit, freq, q.data(), c, qcode.data());
252
- hc.set (qcode.data(), code_size);
245
+ const float* c = index->trained.data() + list_no * nbit;
246
+ binarize_with_freq(nbit, freq, q.data(), c, qcode.data());
247
+ hc.set(qcode.data(), code_size);
253
248
  }
254
249
  }
255
250
 
256
- float distance_to_code (const uint8_t *code) const final {
257
- return hc.hamming (code);
251
+ float distance_to_code(const uint8_t* code) const final {
252
+ return hc.hamming(code);
258
253
  }
259
254
 
260
- size_t scan_codes (size_t list_size,
261
- const uint8_t *codes,
262
- const idx_t *ids,
263
- float *simi, idx_t *idxi,
264
- size_t k) const override
265
- {
255
+ size_t scan_codes(
256
+ size_t list_size,
257
+ const uint8_t* codes,
258
+ const idx_t* ids,
259
+ float* simi,
260
+ idx_t* idxi,
261
+ size_t k) const override {
266
262
  size_t nup = 0;
267
263
  for (size_t j = 0; j < list_size; j++) {
264
+ float dis = hc.hamming(codes);
268
265
 
269
- float dis = hc.hamming (codes);
270
-
271
- if (dis < simi [0]) {
272
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
273
- maxheap_replace_top (k, simi, idxi, dis, id);
266
+ if (dis < simi[0]) {
267
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
268
+ maxheap_replace_top(k, simi, idxi, dis, id);
274
269
  nup++;
275
270
  }
276
271
  codes += code_size;
@@ -278,34 +273,31 @@ struct IVFScanner: InvertedListScanner {
278
273
  return nup;
279
274
  }
280
275
 
281
- void scan_codes_range (size_t list_size,
282
- const uint8_t *codes,
283
- const idx_t *ids,
284
- float radius,
285
- RangeQueryResult & res) const override
286
- {
276
+ void scan_codes_range(
277
+ size_t list_size,
278
+ const uint8_t* codes,
279
+ const idx_t* ids,
280
+ float radius,
281
+ RangeQueryResult& res) const override {
287
282
  for (size_t j = 0; j < list_size; j++) {
288
- float dis = hc.hamming (codes);
283
+ float dis = hc.hamming(codes);
289
284
  if (dis < radius) {
290
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
291
- res.add (dis, id);
285
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
286
+ res.add(dis, id);
292
287
  }
293
288
  codes += code_size;
294
289
  }
295
290
  }
296
-
297
-
298
291
  };
299
292
 
300
293
  } // anonymous namespace
301
294
 
302
- InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
303
- (bool store_pairs) const
304
- {
295
+ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
296
+ bool store_pairs) const {
305
297
  switch (code_size) {
306
298
  #define HANDLE_CODE_SIZE(cs) \
307
- case cs: \
308
- return new IVFScanner<HammingComputer ## cs> (this, store_pairs)
299
+ case cs: \
300
+ return new IVFScanner<HammingComputer##cs>(this, store_pairs)
309
301
  HANDLE_CODE_SIZE(4);
310
302
  HANDLE_CODE_SIZE(8);
311
303
  HANDLE_CODE_SIZE(16);
@@ -314,17 +306,38 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
314
306
  HANDLE_CODE_SIZE(64);
315
307
  #undef HANDLE_CODE_SIZE
316
308
  default:
317
- if (code_size % 8 == 0) {
318
- return new IVFScanner<HammingComputerM8>(this, store_pairs);
319
- } else if (code_size % 4 == 0) {
320
- return new IVFScanner<HammingComputerM4>(this, store_pairs);
321
- } else {
322
- FAISS_THROW_MSG("not supported");
323
- }
309
+ return new IVFScanner<HammingComputerDefault>(this, store_pairs);
324
310
  }
325
-
326
311
  }
327
312
 
313
+ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
314
+ FAISS_THROW_IF_NOT(vt_in->d_out == nbit);
315
+ FAISS_THROW_IF_NOT(vt_in->d_in == d);
316
+ if (own_fields) {
317
+ delete vt;
318
+ }
319
+ vt = vt_in;
320
+ threshold_type = Thresh_global;
321
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist &&
322
+ vt->is_trained;
323
+ own_fields = own;
324
+ }
328
325
 
326
+ /*
327
+ Check that the encoder is a single vector transform followed by a LSH
328
+ that just does thresholding.
329
+ If this is not the case, the linear transform + threhsolds of the IndexLSH
330
+ should be merged into the VectorTransform (which is feasible).
331
+ */
332
+
333
+ void IndexIVFSpectralHash::replace_vt(IndexPreTransform* encoder, bool own) {
334
+ FAISS_THROW_IF_NOT(encoder->chain.size() == 1);
335
+ auto sub_index = dynamic_cast<IndexLSH*>(encoder->index);
336
+ FAISS_THROW_IF_NOT_MSG(sub_index, "final index should be LSH");
337
+ FAISS_THROW_IF_NOT(sub_index->nbits == nbit);
338
+ FAISS_THROW_IF_NOT(!sub_index->rotate_data);
339
+ FAISS_THROW_IF_NOT(!sub_index->train_thresholds);
340
+ replace_vt(encoder->chain[0], own);
341
+ }
329
342
 
330
- } // namespace faiss
343
+ } // namespace faiss
@@ -10,15 +10,14 @@
10
10
  #ifndef FAISS_INDEX_IVFSH_H
11
11
  #define FAISS_INDEX_IVFSH_H
12
12
 
13
-
14
13
  #include <vector>
15
14
 
16
15
  #include <faiss/IndexIVF.h>
17
16
 
18
-
19
17
  namespace faiss {
20
18
 
21
19
  struct VectorTransform;
20
+ struct IndexPreTransform;
22
21
 
23
22
  /** Inverted list that stores binary codes of size nbit. Before the
24
23
  * binary conversion, the dimension of the vectors is transformed from
@@ -27,49 +26,63 @@ struct VectorTransform;
27
26
  * Each coordinate is subtracted from a value determined by
28
27
  * threshold_type, and split into intervals of size period. Half of
29
28
  * the interval is a 0 bit, the other half a 1.
29
+ *
30
30
  */
31
- struct IndexIVFSpectralHash: IndexIVF {
32
-
33
- VectorTransform *vt; // transformation from d to nbit dim
31
+ struct IndexIVFSpectralHash : IndexIVF {
32
+ /// transformation from d to nbit dim
33
+ VectorTransform* vt;
34
+ /// own the vt
34
35
  bool own_fields;
35
36
 
37
+ /// nb of bits of the binary signature
36
38
  int nbit;
39
+ /// interval size for 0s and 1s
37
40
  float period;
38
41
 
39
42
  enum ThresholdType {
40
- Thresh_global,
41
- Thresh_centroid,
42
- Thresh_centroid_half,
43
- Thresh_median
43
+ Thresh_global, ///< global threshold at 0
44
+ Thresh_centroid, ///< compare to centroid
45
+ Thresh_centroid_half, ///< central interval around centroid
46
+ Thresh_median ///< median of training set
44
47
  };
45
48
  ThresholdType threshold_type;
46
49
 
47
- // size nlist * nbit or 0 if Thresh_global
50
+ /// Trained threshold.
51
+ /// size nlist * nbit or 0 if Thresh_global
48
52
  std::vector<float> trained;
49
53
 
50
- IndexIVFSpectralHash (Index * quantizer, size_t d, size_t nlist,
51
- int nbit, float period);
54
+ IndexIVFSpectralHash(
55
+ Index* quantizer,
56
+ size_t d,
57
+ size_t nlist,
58
+ int nbit,
59
+ float period);
52
60
 
53
- IndexIVFSpectralHash ();
61
+ IndexIVFSpectralHash();
54
62
 
55
63
  void train_residual(idx_t n, const float* x) override;
56
64
 
57
- void encode_vectors(idx_t n, const float* x,
58
- const idx_t *list_nos,
59
- uint8_t * codes,
60
- bool include_listnos = false) const override;
65
+ void encode_vectors(
66
+ idx_t n,
67
+ const float* x,
68
+ const idx_t* list_nos,
69
+ uint8_t* codes,
70
+ bool include_listnos = false) const override;
61
71
 
62
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
63
- const override;
72
+ InvertedListScanner* get_InvertedListScanner(
73
+ bool store_pairs) const override;
64
74
 
65
- ~IndexIVFSpectralHash () override;
75
+ /** replace the vector transform for an empty (and possibly untrained) index
76
+ */
77
+ void replace_vt(VectorTransform* vt, bool own = false);
66
78
 
67
- };
79
+ /** convenience function to get the VT from an index constucted by an
80
+ * index_factory (should end in "LSH") */
81
+ void replace_vt(IndexPreTransform* index, bool own = false);
68
82
 
83
+ ~IndexIVFSpectralHash() override;
84
+ };
69
85
 
70
-
71
-
72
- }; // namespace faiss
73
-
86
+ } // namespace faiss
74
87
 
75
88
  #endif