faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -7,43 +7,48 @@
7
7
 
8
8
  // -*- c++ -*-
9
9
 
10
-
11
10
  #include <faiss/IndexIVFSpectralHash.h>
12
11
 
13
- #include <memory>
14
- #include <algorithm>
15
12
  #include <stdint.h>
13
+ #include <algorithm>
14
+ #include <memory>
16
15
 
16
+ #include <faiss/IndexLSH.h>
17
+ #include <faiss/IndexPreTransform.h>
18
+ #include <faiss/VectorTransform.h>
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
17
21
  #include <faiss/utils/hamming.h>
18
22
  #include <faiss/utils/utils.h>
19
- #include <faiss/impl/FaissAssert.h>
20
- #include <faiss/impl/AuxIndexStructures.h>
21
- #include <faiss/VectorTransform.h>
22
23
 
23
24
  namespace faiss {
24
25
 
25
-
26
- IndexIVFSpectralHash::IndexIVFSpectralHash (
27
- Index * quantizer, size_t d, size_t nlist,
28
- int nbit, float period):
29
- IndexIVF (quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2),
30
- nbit (nbit), period (period), threshold_type (Thresh_global)
31
- {
32
- FAISS_THROW_IF_NOT (code_size % 4 == 0);
33
- RandomRotationMatrix *rr = new RandomRotationMatrix (d, nbit);
34
- rr->init (1234);
26
+ IndexIVFSpectralHash::IndexIVFSpectralHash(
27
+ Index* quantizer,
28
+ size_t d,
29
+ size_t nlist,
30
+ int nbit,
31
+ float period)
32
+ : IndexIVF(quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2),
33
+ nbit(nbit),
34
+ period(period),
35
+ threshold_type(Thresh_global) {
36
+ RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
37
+ rr->init(1234);
35
38
  vt = rr;
36
39
  own_fields = true;
37
40
  is_trained = false;
38
41
  }
39
42
 
40
- IndexIVFSpectralHash::IndexIVFSpectralHash():
41
- IndexIVF(), vt(nullptr), own_fields(false),
42
- nbit(0), period(0), threshold_type(Thresh_global)
43
- {}
43
+ IndexIVFSpectralHash::IndexIVFSpectralHash()
44
+ : IndexIVF(),
45
+ vt(nullptr),
46
+ own_fields(false),
47
+ nbit(0),
48
+ period(0),
49
+ threshold_type(Thresh_global) {}
44
50
 
45
- IndexIVFSpectralHash::~IndexIVFSpectralHash ()
46
- {
51
+ IndexIVFSpectralHash::~IndexIVFSpectralHash() {
47
52
  if (own_fields) {
48
53
  delete vt;
49
54
  }
@@ -51,35 +56,33 @@ IndexIVFSpectralHash::~IndexIVFSpectralHash ()
51
56
 
52
57
  namespace {
53
58
 
54
-
55
- float median (size_t n, float *x) {
59
+ float median(size_t n, float* x) {
56
60
  std::sort(x, x + n);
57
61
  if (n % 2 == 1) {
58
- return x [n / 2];
62
+ return x[n / 2];
59
63
  } else {
60
- return (x [n / 2 - 1] + x [n / 2]) / 2;
64
+ return (x[n / 2 - 1] + x[n / 2]) / 2;
61
65
  }
62
66
  }
63
67
 
64
- }
65
-
68
+ } // namespace
66
69
 
67
- void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
68
- {
70
+ void IndexIVFSpectralHash::train_residual(idx_t n, const float* x) {
69
71
  if (!vt->is_trained) {
70
- vt->train (n, x);
72
+ vt->train(n, x);
71
73
  }
72
74
 
73
75
  if (threshold_type == Thresh_global) {
74
76
  // nothing to do
75
77
  return;
76
- } else if (threshold_type == Thresh_centroid ||
77
- threshold_type == Thresh_centroid_half) {
78
+ } else if (
79
+ threshold_type == Thresh_centroid ||
80
+ threshold_type == Thresh_centroid_half) {
78
81
  // convert all centroids with vt
79
- std::vector<float> centroids (nlist * d);
80
- quantizer->reconstruct_n (0, nlist, centroids.data());
82
+ std::vector<float> centroids(nlist * d);
83
+ quantizer->reconstruct_n(0, nlist, centroids.data());
81
84
  trained.resize(nlist * nbit);
82
- vt->apply_noalloc (nlist, centroids.data(), trained.data());
85
+ vt->apply_noalloc(nlist, centroids.data(), trained.data());
83
86
  if (threshold_type == Thresh_centroid_half) {
84
87
  for (size_t i = 0; i < nlist * nbit; i++) {
85
88
  trained[i] -= 0.25 * period;
@@ -90,12 +93,12 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
90
93
  // otherwise train medians
91
94
 
92
95
  // assign
93
- std::unique_ptr<idx_t []> idx (new idx_t [n]);
94
- quantizer->assign (n, x, idx.get());
96
+ std::unique_ptr<idx_t[]> idx(new idx_t[n]);
97
+ quantizer->assign(n, x, idx.get());
95
98
 
96
99
  std::vector<size_t> sizes(nlist + 1);
97
100
  for (size_t i = 0; i < n; i++) {
98
- FAISS_THROW_IF_NOT (idx[i] >= 0);
101
+ FAISS_THROW_IF_NOT(idx[i] >= 0);
99
102
  sizes[idx[i]]++;
100
103
  }
101
104
 
@@ -107,10 +110,10 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
107
110
  }
108
111
 
109
112
  // transform
110
- std::unique_ptr<float []> xt (vt->apply (n, x));
113
+ std::unique_ptr<float[]> xt(vt->apply(n, x));
111
114
 
112
115
  // transpose + reorder
113
- std::unique_ptr<float []> xo (new float[n * nbit]);
116
+ std::unique_ptr<float[]> xo(new float[n * nbit]);
114
117
 
115
118
  for (size_t i = 0; i < n; i++) {
116
119
  size_t idest = sizes[idx[i]]++;
@@ -119,14 +122,14 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
119
122
  }
120
123
  }
121
124
 
122
- trained.resize (n * nbit);
125
+ trained.resize(n * nbit);
123
126
  // compute medians
124
127
  #pragma omp for
125
128
  for (int i = 0; i < nlist; i++) {
126
129
  size_t i0 = i == 0 ? 0 : sizes[i - 1];
127
130
  size_t i1 = sizes[i];
128
131
  for (int j = 0; j < nbit; j++) {
129
- float *xoi = xo.get() + i0 + n * j;
132
+ float* xoi = xo.get() + i0 + n * j;
130
133
  if (i0 == i1) { // nothing to train
131
134
  trained[i * nbit + j] = 0.0;
132
135
  } else if (i1 == i0 + 1) {
@@ -138,75 +141,71 @@ void IndexIVFSpectralHash::train_residual (idx_t n, const float *x)
138
141
  }
139
142
  }
140
143
 
141
-
142
144
  namespace {
143
145
 
144
- void binarize_with_freq(size_t nbit, float freq,
145
- const float *x, const float *c,
146
- uint8_t *codes)
147
- {
148
- memset (codes, 0, (nbit + 7) / 8);
146
+ void binarize_with_freq(
147
+ size_t nbit,
148
+ float freq,
149
+ const float* x,
150
+ const float* c,
151
+ uint8_t* codes) {
152
+ memset(codes, 0, (nbit + 7) / 8);
149
153
  for (size_t i = 0; i < nbit; i++) {
150
154
  float xf = (x[i] - c[i]);
151
- int xi = int(floor(xf * freq));
152
- int bit = xi & 1;
155
+ int64_t xi = int64_t(floor(xf * freq));
156
+ int64_t bit = xi & 1;
153
157
  codes[i >> 3] |= bit << (i & 7);
154
158
  }
155
159
  }
156
160
 
161
+ }; // namespace
157
162
 
158
- };
159
-
160
-
161
-
162
- void IndexIVFSpectralHash::encode_vectors(idx_t n, const float* x_in,
163
- const idx_t *list_nos,
164
- uint8_t * codes,
165
- bool include_listnos) const
166
- {
167
- FAISS_THROW_IF_NOT (is_trained);
163
+ void IndexIVFSpectralHash::encode_vectors(
164
+ idx_t n,
165
+ const float* x_in,
166
+ const idx_t* list_nos,
167
+ uint8_t* codes,
168
+ bool include_listnos) const {
169
+ FAISS_THROW_IF_NOT(is_trained);
168
170
  float freq = 2.0 / period;
169
-
170
- FAISS_THROW_IF_NOT_MSG (!include_listnos, "listnos encoding not supported");
171
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
171
172
 
172
173
  // transform with vt
173
- std::unique_ptr<float []> x (vt->apply (n, x_in));
174
+ std::unique_ptr<float[]> x(vt->apply(n, x_in));
174
175
 
175
- #pragma omp parallel
176
- {
177
- std::vector<float> zero (nbit);
176
+ std::vector<float> zero(nbit);
178
177
 
179
- // each thread takes care of a subset of lists
180
178
  #pragma omp for
181
- for (idx_t i = 0; i < n; i++) {
182
- int64_t list_no = list_nos [i];
183
-
184
- if (list_no >= 0) {
185
- const float *c;
186
- if (threshold_type == Thresh_global) {
187
- c = zero.data();
188
- } else {
189
- c = trained.data() + list_no * nbit;
190
- }
191
- binarize_with_freq (nbit, freq,
192
- x.get() + i * nbit, c,
193
- codes + i * code_size) ;
179
+ for (idx_t i = 0; i < n; i++) {
180
+ int64_t list_no = list_nos[i];
181
+ uint8_t* code = codes + i * (code_size + coarse_size);
182
+
183
+ if (list_no >= 0) {
184
+ if (coarse_size) {
185
+ encode_listno(list_no, code);
186
+ }
187
+ const float* c;
188
+
189
+ if (threshold_type == Thresh_global) {
190
+ c = zero.data();
191
+ } else {
192
+ c = trained.data() + list_no * nbit;
194
193
  }
194
+ binarize_with_freq(
195
+ nbit, freq, x.get() + i * nbit, c, code + coarse_size);
196
+ } else {
197
+ memset(code, 0, code_size + coarse_size);
195
198
  }
196
199
  }
197
200
  }
198
201
 
199
202
  namespace {
200
203
 
201
-
202
- template<class HammingComputer>
203
- struct IVFScanner: InvertedListScanner {
204
-
204
+ template <class HammingComputer>
205
+ struct IVFScanner : InvertedListScanner {
205
206
  // copied from index structure
206
- const IndexIVFSpectralHash *index;
207
- size_t code_size;
207
+ const IndexIVFSpectralHash* index;
208
208
  size_t nbit;
209
- bool store_pairs;
210
209
 
211
210
  float period, freq;
212
211
  std::vector<float> q;
@@ -216,61 +215,57 @@ struct IVFScanner: InvertedListScanner {
216
215
 
217
216
  using idx_t = Index::idx_t;
218
217
 
219
- IVFScanner (const IndexIVFSpectralHash * index,
220
- bool store_pairs):
221
- index (index),
222
- code_size(index->code_size),
223
- nbit(index->nbit),
224
- store_pairs(store_pairs),
225
- period(index->period), freq(2.0 / index->period),
226
- q(nbit), zero(nbit), qcode(code_size),
227
- hc(qcode.data(), code_size)
228
- {
218
+ IVFScanner(const IndexIVFSpectralHash* index, bool store_pairs)
219
+ : index(index),
220
+ nbit(index->nbit),
221
+ period(index->period),
222
+ freq(2.0 / index->period),
223
+ q(nbit),
224
+ zero(nbit),
225
+ qcode(index->code_size),
226
+ hc(qcode.data(), index->code_size) {
227
+ this->store_pairs = store_pairs;
228
+ this->code_size = index->code_size;
229
229
  }
230
230
 
231
-
232
- void set_query (const float *query) override {
231
+ void set_query(const float* query) override {
233
232
  FAISS_THROW_IF_NOT(query);
234
233
  FAISS_THROW_IF_NOT(q.size() == nbit);
235
- index->vt->apply_noalloc (1, query, q.data());
234
+ index->vt->apply_noalloc(1, query, q.data());
236
235
 
237
- if (index->threshold_type ==
238
- IndexIVFSpectralHash::Thresh_global) {
239
- binarize_with_freq
240
- (nbit, freq, q.data(), zero.data(), qcode.data());
241
- hc.set (qcode.data(), code_size);
236
+ if (index->threshold_type == IndexIVFSpectralHash::Thresh_global) {
237
+ binarize_with_freq(nbit, freq, q.data(), zero.data(), qcode.data());
238
+ hc.set(qcode.data(), code_size);
242
239
  }
243
240
  }
244
241
 
245
- idx_t list_no;
246
-
247
- void set_list (idx_t list_no, float /*coarse_dis*/) override {
242
+ void set_list(idx_t list_no, float /*coarse_dis*/) override {
248
243
  this->list_no = list_no;
249
244
  if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
250
- const float *c = index->trained.data() + list_no * nbit;
251
- binarize_with_freq (nbit, freq, q.data(), c, qcode.data());
252
- hc.set (qcode.data(), code_size);
245
+ const float* c = index->trained.data() + list_no * nbit;
246
+ binarize_with_freq(nbit, freq, q.data(), c, qcode.data());
247
+ hc.set(qcode.data(), code_size);
253
248
  }
254
249
  }
255
250
 
256
- float distance_to_code (const uint8_t *code) const final {
257
- return hc.hamming (code);
251
+ float distance_to_code(const uint8_t* code) const final {
252
+ return hc.hamming(code);
258
253
  }
259
254
 
260
- size_t scan_codes (size_t list_size,
261
- const uint8_t *codes,
262
- const idx_t *ids,
263
- float *simi, idx_t *idxi,
264
- size_t k) const override
265
- {
255
+ size_t scan_codes(
256
+ size_t list_size,
257
+ const uint8_t* codes,
258
+ const idx_t* ids,
259
+ float* simi,
260
+ idx_t* idxi,
261
+ size_t k) const override {
266
262
  size_t nup = 0;
267
263
  for (size_t j = 0; j < list_size; j++) {
264
+ float dis = hc.hamming(codes);
268
265
 
269
- float dis = hc.hamming (codes);
270
-
271
- if (dis < simi [0]) {
272
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
273
- maxheap_replace_top (k, simi, idxi, dis, id);
266
+ if (dis < simi[0]) {
267
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
268
+ maxheap_replace_top(k, simi, idxi, dis, id);
274
269
  nup++;
275
270
  }
276
271
  codes += code_size;
@@ -278,34 +273,31 @@ struct IVFScanner: InvertedListScanner {
278
273
  return nup;
279
274
  }
280
275
 
281
- void scan_codes_range (size_t list_size,
282
- const uint8_t *codes,
283
- const idx_t *ids,
284
- float radius,
285
- RangeQueryResult & res) const override
286
- {
276
+ void scan_codes_range(
277
+ size_t list_size,
278
+ const uint8_t* codes,
279
+ const idx_t* ids,
280
+ float radius,
281
+ RangeQueryResult& res) const override {
287
282
  for (size_t j = 0; j < list_size; j++) {
288
- float dis = hc.hamming (codes);
283
+ float dis = hc.hamming(codes);
289
284
  if (dis < radius) {
290
- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
291
- res.add (dis, id);
285
+ int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
286
+ res.add(dis, id);
292
287
  }
293
288
  codes += code_size;
294
289
  }
295
290
  }
296
-
297
-
298
291
  };
299
292
 
300
293
  } // anonymous namespace
301
294
 
302
- InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
303
- (bool store_pairs) const
304
- {
295
+ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
296
+ bool store_pairs) const {
305
297
  switch (code_size) {
306
298
  #define HANDLE_CODE_SIZE(cs) \
307
- case cs: \
308
- return new IVFScanner<HammingComputer ## cs> (this, store_pairs)
299
+ case cs: \
300
+ return new IVFScanner<HammingComputer##cs>(this, store_pairs)
309
301
  HANDLE_CODE_SIZE(4);
310
302
  HANDLE_CODE_SIZE(8);
311
303
  HANDLE_CODE_SIZE(16);
@@ -314,17 +306,38 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
314
306
  HANDLE_CODE_SIZE(64);
315
307
  #undef HANDLE_CODE_SIZE
316
308
  default:
317
- if (code_size % 8 == 0) {
318
- return new IVFScanner<HammingComputerM8>(this, store_pairs);
319
- } else if (code_size % 4 == 0) {
320
- return new IVFScanner<HammingComputerM4>(this, store_pairs);
321
- } else {
322
- FAISS_THROW_MSG("not supported");
323
- }
309
+ return new IVFScanner<HammingComputerDefault>(this, store_pairs);
324
310
  }
325
-
326
311
  }
327
312
 
313
+ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
314
+ FAISS_THROW_IF_NOT(vt_in->d_out == nbit);
315
+ FAISS_THROW_IF_NOT(vt_in->d_in == d);
316
+ if (own_fields) {
317
+ delete vt;
318
+ }
319
+ vt = vt_in;
320
+ threshold_type = Thresh_global;
321
+ is_trained = quantizer->is_trained && quantizer->ntotal == nlist &&
322
+ vt->is_trained;
323
+ own_fields = own;
324
+ }
328
325
 
326
+ /*
327
+ Check that the encoder is a single vector transform followed by a LSH
328
+ that just does thresholding.
329
+ If this is not the case, the linear transform + threhsolds of the IndexLSH
330
+ should be merged into the VectorTransform (which is feasible).
331
+ */
332
+
333
+ void IndexIVFSpectralHash::replace_vt(IndexPreTransform* encoder, bool own) {
334
+ FAISS_THROW_IF_NOT(encoder->chain.size() == 1);
335
+ auto sub_index = dynamic_cast<IndexLSH*>(encoder->index);
336
+ FAISS_THROW_IF_NOT_MSG(sub_index, "final index should be LSH");
337
+ FAISS_THROW_IF_NOT(sub_index->nbits == nbit);
338
+ FAISS_THROW_IF_NOT(!sub_index->rotate_data);
339
+ FAISS_THROW_IF_NOT(!sub_index->train_thresholds);
340
+ replace_vt(encoder->chain[0], own);
341
+ }
329
342
 
330
- } // namespace faiss
343
+ } // namespace faiss
@@ -10,15 +10,14 @@
10
10
  #ifndef FAISS_INDEX_IVFSH_H
11
11
  #define FAISS_INDEX_IVFSH_H
12
12
 
13
-
14
13
  #include <vector>
15
14
 
16
15
  #include <faiss/IndexIVF.h>
17
16
 
18
-
19
17
  namespace faiss {
20
18
 
21
19
  struct VectorTransform;
20
+ struct IndexPreTransform;
22
21
 
23
22
  /** Inverted list that stores binary codes of size nbit. Before the
24
23
  * binary conversion, the dimension of the vectors is transformed from
@@ -27,49 +26,63 @@ struct VectorTransform;
27
26
  * Each coordinate is subtracted from a value determined by
28
27
  * threshold_type, and split into intervals of size period. Half of
29
28
  * the interval is a 0 bit, the other half a 1.
29
+ *
30
30
  */
31
- struct IndexIVFSpectralHash: IndexIVF {
32
-
33
- VectorTransform *vt; // transformation from d to nbit dim
31
+ struct IndexIVFSpectralHash : IndexIVF {
32
+ /// transformation from d to nbit dim
33
+ VectorTransform* vt;
34
+ /// own the vt
34
35
  bool own_fields;
35
36
 
37
+ /// nb of bits of the binary signature
36
38
  int nbit;
39
+ /// interval size for 0s and 1s
37
40
  float period;
38
41
 
39
42
  enum ThresholdType {
40
- Thresh_global,
41
- Thresh_centroid,
42
- Thresh_centroid_half,
43
- Thresh_median
43
+ Thresh_global, ///< global threshold at 0
44
+ Thresh_centroid, ///< compare to centroid
45
+ Thresh_centroid_half, ///< central interval around centroid
46
+ Thresh_median ///< median of training set
44
47
  };
45
48
  ThresholdType threshold_type;
46
49
 
47
- // size nlist * nbit or 0 if Thresh_global
50
+ /// Trained threshold.
51
+ /// size nlist * nbit or 0 if Thresh_global
48
52
  std::vector<float> trained;
49
53
 
50
- IndexIVFSpectralHash (Index * quantizer, size_t d, size_t nlist,
51
- int nbit, float period);
54
+ IndexIVFSpectralHash(
55
+ Index* quantizer,
56
+ size_t d,
57
+ size_t nlist,
58
+ int nbit,
59
+ float period);
52
60
 
53
- IndexIVFSpectralHash ();
61
+ IndexIVFSpectralHash();
54
62
 
55
63
  void train_residual(idx_t n, const float* x) override;
56
64
 
57
- void encode_vectors(idx_t n, const float* x,
58
- const idx_t *list_nos,
59
- uint8_t * codes,
60
- bool include_listnos = false) const override;
65
+ void encode_vectors(
66
+ idx_t n,
67
+ const float* x,
68
+ const idx_t* list_nos,
69
+ uint8_t* codes,
70
+ bool include_listnos = false) const override;
61
71
 
62
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
63
- const override;
72
+ InvertedListScanner* get_InvertedListScanner(
73
+ bool store_pairs) const override;
64
74
 
65
- ~IndexIVFSpectralHash () override;
75
+ /** replace the vector transform for an empty (and possibly untrained) index
76
+ */
77
+ void replace_vt(VectorTransform* vt, bool own = false);
66
78
 
67
- };
79
+ /** convenience function to get the VT from an index constucted by an
80
+ * index_factory (should end in "LSH") */
81
+ void replace_vt(IndexPreTransform* index, bool own = false);
68
82
 
83
+ ~IndexIVFSpectralHash() override;
84
+ };
69
85
 
70
-
71
-
72
- }; // namespace faiss
73
-
86
+ } // namespace faiss
74
87
 
75
88
  #endif