faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,231 +9,181 @@
9
9
 
10
10
  #include <faiss/IndexScalarQuantizer.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <algorithm>
13
+ #include <cstdio>
14
14
 
15
15
  #include <omp.h>
16
16
 
17
- #include <faiss/utils/utils.h>
18
- #include <faiss/impl/FaissAssert.h>
19
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
20
19
  #include <faiss/impl/ScalarQuantizer.h>
20
+ #include <faiss/utils/utils.h>
21
21
 
22
22
  namespace faiss {
23
23
 
24
-
25
-
26
24
  /*******************************************************************
27
25
  * IndexScalarQuantizer implementation
28
26
  ********************************************************************/
29
27
 
30
- IndexScalarQuantizer::IndexScalarQuantizer
31
- (int d, ScalarQuantizer::QuantizerType qtype,
32
- MetricType metric):
33
- Index(d, metric),
34
- sq (d, qtype)
35
- {
36
- is_trained =
37
- qtype == ScalarQuantizer::QT_fp16 ||
38
- qtype == ScalarQuantizer::QT_8bit_direct;
28
+ IndexScalarQuantizer::IndexScalarQuantizer(
29
+ int d,
30
+ ScalarQuantizer::QuantizerType qtype,
31
+ MetricType metric)
32
+ : IndexFlatCodes(0, d, metric), sq(d, qtype) {
33
+ is_trained = qtype == ScalarQuantizer::QT_fp16 ||
34
+ qtype == ScalarQuantizer::QT_8bit_direct;
39
35
  code_size = sq.code_size;
40
36
  }
41
37
 
38
+ IndexScalarQuantizer::IndexScalarQuantizer()
39
+ : IndexScalarQuantizer(0, ScalarQuantizer::QT_8bit) {}
42
40
 
43
- IndexScalarQuantizer::IndexScalarQuantizer ():
44
- IndexScalarQuantizer(0, ScalarQuantizer::QT_8bit)
45
- {}
46
-
47
- void IndexScalarQuantizer::train(idx_t n, const float* x)
48
- {
41
+ void IndexScalarQuantizer::train(idx_t n, const float* x) {
49
42
  sq.train(n, x);
50
43
  is_trained = true;
51
44
  }
52
45
 
53
- void IndexScalarQuantizer::add(idx_t n, const float* x)
54
- {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- codes.resize ((n + ntotal) * code_size);
57
- sq.compute_codes (x, &codes[ntotal * code_size], n);
58
- ntotal += n;
59
- }
60
-
61
-
62
46
  void IndexScalarQuantizer::search(
63
47
  idx_t n,
64
48
  const float* x,
65
49
  idx_t k,
66
50
  float* distances,
67
- idx_t* labels) const
68
- {
69
- FAISS_THROW_IF_NOT (is_trained);
70
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2 ||
71
- metric_type == METRIC_INNER_PRODUCT);
51
+ idx_t* labels) const {
52
+ FAISS_THROW_IF_NOT(k > 0);
53
+
54
+ FAISS_THROW_IF_NOT(is_trained);
55
+ FAISS_THROW_IF_NOT(
56
+ metric_type == METRIC_L2 || metric_type == METRIC_INNER_PRODUCT);
72
57
 
73
58
  #pragma omp parallel
74
59
  {
75
- InvertedListScanner* scanner = sq.select_InvertedListScanner
76
- (metric_type, nullptr, true);
60
+ InvertedListScanner* scanner =
61
+ sq.select_InvertedListScanner(metric_type, nullptr, true);
77
62
  ScopeDeleter1<InvertedListScanner> del(scanner);
63
+ scanner->list_no = 0; // directly the list number
78
64
 
79
65
  #pragma omp for
80
66
  for (idx_t i = 0; i < n; i++) {
81
- float * D = distances + k * i;
82
- idx_t * I = labels + k * i;
67
+ float* D = distances + k * i;
68
+ idx_t* I = labels + k * i;
83
69
  // re-order heap
84
70
  if (metric_type == METRIC_L2) {
85
- maxheap_heapify (k, D, I);
71
+ maxheap_heapify(k, D, I);
86
72
  } else {
87
- minheap_heapify (k, D, I);
73
+ minheap_heapify(k, D, I);
88
74
  }
89
- scanner->set_query (x + i * d);
90
- scanner->scan_codes (ntotal, codes.data(),
91
- nullptr, D, I, k);
75
+ scanner->set_query(x + i * d);
76
+ scanner->scan_codes(ntotal, codes.data(), nullptr, D, I, k);
92
77
 
93
78
  // re-order heap
94
79
  if (metric_type == METRIC_L2) {
95
- maxheap_reorder (k, D, I);
80
+ maxheap_reorder(k, D, I);
96
81
  } else {
97
- minheap_reorder (k, D, I);
82
+ minheap_reorder(k, D, I);
98
83
  }
99
84
  }
100
85
  }
101
-
102
86
  }
103
87
 
104
-
105
- DistanceComputer *IndexScalarQuantizer::get_distance_computer () const
106
- {
107
- ScalarQuantizer::SQDistanceComputer *dc =
108
- sq.get_distance_computer (metric_type);
88
+ DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
89
+ ScalarQuantizer::SQDistanceComputer* dc =
90
+ sq.get_distance_computer(metric_type);
109
91
  dc->code_size = sq.code_size;
110
92
  dc->codes = codes.data();
111
93
  return dc;
112
94
  }
113
95
 
114
-
115
- void IndexScalarQuantizer::reset()
116
- {
117
- codes.clear();
118
- ntotal = 0;
119
- }
120
-
121
- void IndexScalarQuantizer::reconstruct_n(
122
- idx_t i0, idx_t ni, float* recons) const
123
- {
124
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer ());
125
- for (size_t i = 0; i < ni; i++) {
126
- squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d);
127
- }
128
- }
129
-
130
- void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const
131
- {
132
- reconstruct_n(key, 1, recons);
133
- }
134
-
135
96
  /* Codec interface */
136
- size_t IndexScalarQuantizer::sa_code_size () const
137
- {
138
- return sq.code_size;
139
- }
140
97
 
141
- void IndexScalarQuantizer::sa_encode (idx_t n, const float *x,
142
- uint8_t *bytes) const
143
- {
144
- FAISS_THROW_IF_NOT (is_trained);
145
- sq.compute_codes (x, bytes, n);
98
+ void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
99
+ const {
100
+ FAISS_THROW_IF_NOT(is_trained);
101
+ sq.compute_codes(x, bytes, n);
146
102
  }
147
103
 
148
- void IndexScalarQuantizer::sa_decode (idx_t n, const uint8_t *bytes,
149
- float *x) const
150
- {
151
- FAISS_THROW_IF_NOT (is_trained);
104
+ void IndexScalarQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
105
+ const {
106
+ FAISS_THROW_IF_NOT(is_trained);
152
107
  sq.decode(bytes, x, n);
153
108
  }
154
109
 
155
-
156
-
157
110
  /*******************************************************************
158
111
  * IndexIVFScalarQuantizer implementation
159
112
  ********************************************************************/
160
113
 
161
- IndexIVFScalarQuantizer::IndexIVFScalarQuantizer (
162
- Index *quantizer, size_t d, size_t nlist,
163
- ScalarQuantizer::QuantizerType qtype,
164
- MetricType metric, bool encode_residual)
165
- : IndexIVF(quantizer, d, nlist, 0, metric),
166
- sq(d, qtype),
167
- by_residual(encode_residual)
168
- {
114
+ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
115
+ Index* quantizer,
116
+ size_t d,
117
+ size_t nlist,
118
+ ScalarQuantizer::QuantizerType qtype,
119
+ MetricType metric,
120
+ bool encode_residual)
121
+ : IndexIVF(quantizer, d, nlist, 0, metric),
122
+ sq(d, qtype),
123
+ by_residual(encode_residual) {
169
124
  code_size = sq.code_size;
170
125
  // was not known at construction time
171
126
  invlists->code_size = code_size;
172
127
  is_trained = false;
173
128
  }
174
129
 
175
- IndexIVFScalarQuantizer::IndexIVFScalarQuantizer ():
176
- IndexIVF(),
177
- by_residual(true)
178
- {
179
- }
130
+ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer()
131
+ : IndexIVF(), by_residual(true) {}
180
132
 
181
- void IndexIVFScalarQuantizer::train_residual (idx_t n, const float *x)
182
- {
133
+ void IndexIVFScalarQuantizer::train_residual(idx_t n, const float* x) {
183
134
  sq.train_residual(n, x, quantizer, by_residual, verbose);
184
135
  }
185
136
 
186
- void IndexIVFScalarQuantizer::encode_vectors(idx_t n, const float* x,
187
- const idx_t *list_nos,
188
- uint8_t * codes,
189
- bool include_listnos) const
190
- {
191
- std::unique_ptr<ScalarQuantizer::Quantizer> squant (sq.select_quantizer ());
192
- size_t coarse_size = include_listnos ? coarse_code_size () : 0;
137
+ void IndexIVFScalarQuantizer::encode_vectors(
138
+ idx_t n,
139
+ const float* x,
140
+ const idx_t* list_nos,
141
+ uint8_t* codes,
142
+ bool include_listnos) const {
143
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
144
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
193
145
  memset(codes, 0, (code_size + coarse_size) * n);
194
146
 
195
- #pragma omp parallel if(n > 1000)
147
+ #pragma omp parallel if (n > 1000)
196
148
  {
197
- std::vector<float> residual (d);
149
+ std::vector<float> residual(d);
198
150
 
199
151
  #pragma omp for
200
152
  for (idx_t i = 0; i < n; i++) {
201
- int64_t list_no = list_nos [i];
153
+ int64_t list_no = list_nos[i];
202
154
  if (list_no >= 0) {
203
- const float *xi = x + i * d;
204
- uint8_t *code = codes + i * (code_size + coarse_size);
155
+ const float* xi = x + i * d;
156
+ uint8_t* code = codes + i * (code_size + coarse_size);
205
157
  if (by_residual) {
206
- quantizer->compute_residual (
207
- xi, residual.data(), list_no);
208
- xi = residual.data ();
158
+ quantizer->compute_residual(xi, residual.data(), list_no);
159
+ xi = residual.data();
209
160
  }
210
161
  if (coarse_size) {
211
- encode_listno (list_no, code);
162
+ encode_listno(list_no, code);
212
163
  }
213
- squant->encode_vector (xi, code + coarse_size);
164
+ squant->encode_vector(xi, code + coarse_size);
214
165
  }
215
166
  }
216
167
  }
217
168
  }
218
169
 
219
- void IndexIVFScalarQuantizer::sa_decode (idx_t n, const uint8_t *codes,
220
- float *x) const
221
- {
222
- std::unique_ptr<ScalarQuantizer::Quantizer> squant (sq.select_quantizer ());
223
- size_t coarse_size = coarse_code_size ();
170
+ void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
171
+ const {
172
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
173
+ size_t coarse_size = coarse_code_size();
224
174
 
225
- #pragma omp parallel if(n > 1000)
175
+ #pragma omp parallel if (n > 1000)
226
176
  {
227
- std::vector<float> residual (d);
177
+ std::vector<float> residual(d);
228
178
 
229
179
  #pragma omp for
230
180
  for (idx_t i = 0; i < n; i++) {
231
- const uint8_t *code = codes + i * (code_size + coarse_size);
232
- int64_t list_no = decode_listno (code);
233
- float *xi = x + i * d;
234
- squant->decode_vector (code + coarse_size, xi);
181
+ const uint8_t* code = codes + i * (code_size + coarse_size);
182
+ int64_t list_no = decode_listno(code);
183
+ float* xi = x + i * d;
184
+ squant->decode_vector(code + coarse_size, xi);
235
185
  if (by_residual) {
236
- quantizer->reconstruct (list_no, residual.data());
186
+ quantizer->reconstruct(list_no, residual.data());
237
187
  for (size_t j = 0; j < d; j++) {
238
188
  xi[j] += residual[j];
239
189
  }
@@ -242,83 +192,72 @@ void IndexIVFScalarQuantizer::sa_decode (idx_t n, const uint8_t *codes,
242
192
  }
243
193
  }
244
194
 
195
+ void IndexIVFScalarQuantizer::add_core(
196
+ idx_t n,
197
+ const float* x,
198
+ const idx_t* xids,
199
+ const idx_t* coarse_idx) {
200
+ FAISS_THROW_IF_NOT(is_trained);
245
201
 
246
-
247
- void IndexIVFScalarQuantizer::add_with_ids
248
- (idx_t n, const float * x, const idx_t *xids)
249
- {
250
- FAISS_THROW_IF_NOT (is_trained);
251
- std::unique_ptr<int64_t []> idx (new int64_t [n]);
252
- quantizer->assign (n, x, idx.get());
253
202
  size_t nadd = 0;
254
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer ());
203
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
255
204
 
256
- DirectMapAdd dm_add (direct_map, n, xids);
205
+ DirectMapAdd dm_add(direct_map, n, xids);
257
206
 
258
- #pragma omp parallel reduction(+: nadd)
207
+ #pragma omp parallel reduction(+ : nadd)
259
208
  {
260
- std::vector<float> residual (d);
261
- std::vector<uint8_t> one_code (code_size);
209
+ std::vector<float> residual(d);
210
+ std::vector<uint8_t> one_code(code_size);
262
211
  int nt = omp_get_num_threads();
263
212
  int rank = omp_get_thread_num();
264
213
 
265
214
  // each thread takes care of a subset of lists
266
215
  for (size_t i = 0; i < n; i++) {
267
- int64_t list_no = idx [i];
216
+ int64_t list_no = coarse_idx[i];
268
217
  if (list_no >= 0 && list_no % nt == rank) {
269
218
  int64_t id = xids ? xids[i] : ntotal + i;
270
219
 
271
- const float * xi = x + i * d;
220
+ const float* xi = x + i * d;
272
221
  if (by_residual) {
273
- quantizer->compute_residual (xi, residual.data(), list_no);
222
+ quantizer->compute_residual(xi, residual.data(), list_no);
274
223
  xi = residual.data();
275
224
  }
276
225
 
277
- memset (one_code.data(), 0, code_size);
278
- squant->encode_vector (xi, one_code.data());
226
+ memset(one_code.data(), 0, code_size);
227
+ squant->encode_vector(xi, one_code.data());
279
228
 
280
- size_t ofs = invlists->add_entry (list_no, id, one_code.data());
229
+ size_t ofs = invlists->add_entry(list_no, id, one_code.data());
281
230
 
282
- dm_add.add (i, list_no, ofs);
231
+ dm_add.add(i, list_no, ofs);
283
232
  nadd++;
284
233
 
285
234
  } else if (rank == 0 && list_no == -1) {
286
- dm_add.add (i, -1, 0);
235
+ dm_add.add(i, -1, 0);
287
236
  }
288
237
  }
289
238
  }
290
239
 
291
-
292
240
  ntotal += n;
293
241
  }
294
242
 
295
-
296
-
297
-
298
-
299
- InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner
300
- (bool store_pairs) const
301
- {
302
- return sq.select_InvertedListScanner (metric_type, quantizer, store_pairs,
303
- by_residual);
243
+ InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
244
+ bool store_pairs) const {
245
+ return sq.select_InvertedListScanner(
246
+ metric_type, quantizer, store_pairs, by_residual);
304
247
  }
305
248
 
306
-
307
- void IndexIVFScalarQuantizer::reconstruct_from_offset (int64_t list_no,
308
- int64_t offset,
309
- float* recons) const
310
- {
249
+ void IndexIVFScalarQuantizer::reconstruct_from_offset(
250
+ int64_t list_no,
251
+ int64_t offset,
252
+ float* recons) const {
311
253
  std::vector<float> centroid(d);
312
- quantizer->reconstruct (list_no, centroid.data());
254
+ quantizer->reconstruct(list_no, centroid.data());
313
255
 
314
- const uint8_t* code = invlists->get_single_code (list_no, offset);
315
- sq.decode (code, recons, 1);
256
+ const uint8_t* code = invlists->get_single_code(list_no, offset);
257
+ sq.decode(code, recons, 1);
316
258
  for (int i = 0; i < d; ++i) {
317
259
  recons[i] += centroid[i];
318
260
  }
319
261
  }
320
262
 
321
-
322
-
323
-
324
263
  } // namespace faiss
@@ -13,10 +13,10 @@
13
13
  #include <stdint.h>
14
14
  #include <vector>
15
15
 
16
+ #include <faiss/IndexFlatCodes.h>
16
17
  #include <faiss/IndexIVF.h>
17
18
  #include <faiss/impl/ScalarQuantizer.h>
18
19
 
19
-
20
20
  namespace faiss {
21
21
 
22
22
  /**
@@ -25,103 +25,85 @@ namespace faiss {
25
25
  * (default).
26
26
  */
27
27
 
28
-
29
-
30
-
31
- struct IndexScalarQuantizer: Index {
28
+ struct IndexScalarQuantizer : IndexFlatCodes {
32
29
  /// Used to encode the vectors
33
30
  ScalarQuantizer sq;
34
31
 
35
- /// Codes. Size ntotal * pq.code_size
36
- std::vector<uint8_t> codes;
37
-
38
- size_t code_size;
39
-
40
32
  /** Constructor.
41
33
  *
42
34
  * @param d dimensionality of the input vectors
43
35
  * @param M number of subquantizers
44
36
  * @param nbits number of bit per subvector index
45
37
  */
46
- IndexScalarQuantizer (int d,
47
- ScalarQuantizer::QuantizerType qtype,
48
- MetricType metric = METRIC_L2);
38
+ IndexScalarQuantizer(
39
+ int d,
40
+ ScalarQuantizer::QuantizerType qtype,
41
+ MetricType metric = METRIC_L2);
49
42
 
50
- IndexScalarQuantizer ();
43
+ IndexScalarQuantizer();
51
44
 
52
45
  void train(idx_t n, const float* x) override;
53
46
 
54
- void add(idx_t n, const float* x) override;
55
-
56
47
  void search(
57
- idx_t n,
58
- const float* x,
59
- idx_t k,
60
- float* distances,
61
- idx_t* labels) const override;
62
-
63
- void reset() override;
64
-
65
- void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
48
+ idx_t n,
49
+ const float* x,
50
+ idx_t k,
51
+ float* distances,
52
+ idx_t* labels) const override;
66
53
 
67
- void reconstruct(idx_t key, float* recons) const override;
68
-
69
- DistanceComputer *get_distance_computer () const override;
54
+ DistanceComputer* get_distance_computer() const override;
70
55
 
71
56
  /* standalone codec interface */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
76
-
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
79
-
57
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
80
58
 
59
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
81
60
  };
82
61
 
83
-
84
- /** An IVF implementation where the components of the residuals are
62
+ /** An IVF implementation where the components of the residuals are
85
63
  * encoded with a scalar quantizer. All distance computations
86
64
  * are asymmetric, so the encoded vectors are decoded and approximate
87
65
  * distances are computed.
88
66
  */
89
67
 
90
- struct IndexIVFScalarQuantizer: IndexIVF {
68
+ struct IndexIVFScalarQuantizer : IndexIVF {
91
69
  ScalarQuantizer sq;
92
70
  bool by_residual;
93
71
 
94
- IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist,
95
- ScalarQuantizer::QuantizerType qtype,
96
- MetricType metric = METRIC_L2,
97
- bool encode_residual = true);
72
+ IndexIVFScalarQuantizer(
73
+ Index* quantizer,
74
+ size_t d,
75
+ size_t nlist,
76
+ ScalarQuantizer::QuantizerType qtype,
77
+ MetricType metric = METRIC_L2,
78
+ bool encode_residual = true);
98
79
 
99
80
  IndexIVFScalarQuantizer();
100
81
 
101
82
  void train_residual(idx_t n, const float* x) override;
102
83
 
103
- void encode_vectors(idx_t n, const float* x,
104
- const idx_t *list_nos,
105
- uint8_t * codes,
106
- bool include_listnos=false) const override;
84
+ void encode_vectors(
85
+ idx_t n,
86
+ const float* x,
87
+ const idx_t* list_nos,
88
+ uint8_t* codes,
89
+ bool include_listnos = false) const override;
107
90
 
108
- void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
91
+ void add_core(
92
+ idx_t n,
93
+ const float* x,
94
+ const idx_t* xids,
95
+ const idx_t* precomputed_idx) override;
109
96
 
110
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
111
- const override;
97
+ InvertedListScanner* get_InvertedListScanner(
98
+ bool store_pairs) const override;
112
99
 
113
-
114
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
115
- float* recons) const override;
100
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
101
+ const override;
116
102
 
117
103
  /* standalone codec interface */
118
- void sa_decode (idx_t n, const uint8_t *bytes,
119
- float *x) const override;
120
-
104
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
121
105
  };
122
106
 
123
-
124
- }
125
-
107
+ } // namespace faiss
126
108
 
127
109
  #endif