faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,132 +9,88 @@
9
9
 
10
10
  #include <faiss/IndexFlat.h>
11
11
 
12
- #include <cstring>
12
+ #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/FaissAssert.h>
14
+ #include <faiss/utils/Heap.h>
13
15
  #include <faiss/utils/distances.h>
14
16
  #include <faiss/utils/extra_distances.h>
15
17
  #include <faiss/utils/utils.h>
16
- #include <faiss/utils/Heap.h>
17
- #include <faiss/impl/FaissAssert.h>
18
- #include <faiss/impl/AuxIndexStructures.h>
19
-
18
+ #include <cstring>
20
19
 
21
20
  namespace faiss {
22
21
 
23
- IndexFlat::IndexFlat (idx_t d, MetricType metric):
24
- Index(d, metric)
25
- {
26
- }
27
-
22
+ IndexFlat::IndexFlat(idx_t d, MetricType metric)
23
+ : IndexFlatCodes(sizeof(float) * d, d, metric) {}
28
24
 
25
+ void IndexFlat::search(
26
+ idx_t n,
27
+ const float* x,
28
+ idx_t k,
29
+ float* distances,
30
+ idx_t* labels) const {
31
+ FAISS_THROW_IF_NOT(k > 0);
29
32
 
30
- void IndexFlat::add (idx_t n, const float *x) {
31
- xb.insert(xb.end(), x, x + n * d);
32
- ntotal += n;
33
- }
34
-
35
-
36
- void IndexFlat::reset() {
37
- xb.clear();
38
- ntotal = 0;
39
- }
40
-
41
-
42
- void IndexFlat::search (idx_t n, const float *x, idx_t k,
43
- float *distances, idx_t *labels) const
44
- {
45
33
  // we see the distances and labels as heaps
46
34
 
47
35
  if (metric_type == METRIC_INNER_PRODUCT) {
48
- float_minheap_array_t res = {
49
- size_t(n), size_t(k), labels, distances};
50
- knn_inner_product (x, xb.data(), d, n, ntotal, &res);
36
+ float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
37
+ knn_inner_product(x, get_xb(), d, n, ntotal, &res);
51
38
  } else if (metric_type == METRIC_L2) {
52
- float_maxheap_array_t res = {
53
- size_t(n), size_t(k), labels, distances};
54
- knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
39
+ float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
40
+ knn_L2sqr(x, get_xb(), d, n, ntotal, &res);
55
41
  } else {
56
- float_maxheap_array_t res = {
57
- size_t(n), size_t(k), labels, distances};
58
- knn_extra_metrics (x, xb.data(), d, n, ntotal,
59
- metric_type, metric_arg,
60
- &res);
42
+ float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
43
+ knn_extra_metrics(
44
+ x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
61
45
  }
62
46
  }
63
47
 
64
- void IndexFlat::range_search (idx_t n, const float *x, float radius,
65
- RangeSearchResult *result) const
66
- {
48
+ void IndexFlat::range_search(
49
+ idx_t n,
50
+ const float* x,
51
+ float radius,
52
+ RangeSearchResult* result) const {
67
53
  switch (metric_type) {
68
- case METRIC_INNER_PRODUCT:
69
- range_search_inner_product (x, xb.data(), d, n, ntotal,
70
- radius, result);
71
- break;
72
- case METRIC_L2:
73
- range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
74
- break;
75
- default:
76
- FAISS_THROW_MSG("metric type not supported");
54
+ case METRIC_INNER_PRODUCT:
55
+ range_search_inner_product(
56
+ x, get_xb(), d, n, ntotal, radius, result);
57
+ break;
58
+ case METRIC_L2:
59
+ range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result);
60
+ break;
61
+ default:
62
+ FAISS_THROW_MSG("metric type not supported");
77
63
  }
78
64
  }
79
65
 
80
-
81
- void IndexFlat::compute_distance_subset (
82
- idx_t n,
83
- const float *x,
84
- idx_t k,
85
- float *distances,
86
- const idx_t *labels) const
87
- {
66
+ void IndexFlat::compute_distance_subset(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ const idx_t* labels) const {
88
72
  switch (metric_type) {
89
73
  case METRIC_INNER_PRODUCT:
90
- fvec_inner_products_by_idx (
91
- distances,
92
- x, xb.data(), labels, d, n, k);
74
+ fvec_inner_products_by_idx(distances, x, get_xb(), labels, d, n, k);
93
75
  break;
94
76
  case METRIC_L2:
95
- fvec_L2sqr_by_idx (
96
- distances,
97
- x, xb.data(), labels, d, n, k);
77
+ fvec_L2sqr_by_idx(distances, x, get_xb(), labels, d, n, k);
98
78
  break;
99
79
  default:
100
80
  FAISS_THROW_MSG("metric type not supported");
101
81
  }
102
-
103
82
  }
104
83
 
105
- size_t IndexFlat::remove_ids (const IDSelector & sel)
106
- {
107
- idx_t j = 0;
108
- for (idx_t i = 0; i < ntotal; i++) {
109
- if (sel.is_member (i)) {
110
- // should be removed
111
- } else {
112
- if (i > j) {
113
- memmove (&xb[d * j], &xb[d * i], sizeof(xb[0]) * d);
114
- }
115
- j++;
116
- }
117
- }
118
- size_t nremove = ntotal - j;
119
- if (nremove > 0) {
120
- ntotal = j;
121
- xb.resize (ntotal * d);
122
- }
123
- return nremove;
124
- }
125
-
126
-
127
84
  namespace {
128
85
 
129
-
130
86
  struct FlatL2Dis : DistanceComputer {
131
87
  size_t d;
132
88
  Index::idx_t nb;
133
- const float *q;
134
- const float *b;
89
+ const float* q;
90
+ const float* b;
135
91
  size_t ndis;
136
92
 
137
- float operator () (idx_t i) override {
93
+ float operator()(idx_t i) override {
138
94
  ndis++;
139
95
  return fvec_L2sqr(q, b + i * d, d);
140
96
  }
@@ -143,14 +99,14 @@ struct FlatL2Dis : DistanceComputer {
143
99
  return fvec_L2sqr(b + j * d, b + i * d, d);
144
100
  }
145
101
 
146
- explicit FlatL2Dis(const IndexFlat& storage, const float *q = nullptr)
147
- : d(storage.d),
148
- nb(storage.ntotal),
149
- q(q),
150
- b(storage.xb.data()),
151
- ndis(0) {}
102
+ explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
103
+ : d(storage.d),
104
+ nb(storage.ntotal),
105
+ q(q),
106
+ b(storage.get_xb()),
107
+ ndis(0) {}
152
108
 
153
- void set_query(const float *x) override {
109
+ void set_query(const float* x) override {
154
110
  q = x;
155
111
  }
156
112
  };
@@ -158,128 +114,106 @@ struct FlatL2Dis : DistanceComputer {
158
114
  struct FlatIPDis : DistanceComputer {
159
115
  size_t d;
160
116
  Index::idx_t nb;
161
- const float *q;
162
- const float *b;
117
+ const float* q;
118
+ const float* b;
163
119
  size_t ndis;
164
120
 
165
- float operator () (idx_t i) override {
121
+ float operator()(idx_t i) override {
166
122
  ndis++;
167
- return fvec_inner_product (q, b + i * d, d);
123
+ return fvec_inner_product(q, b + i * d, d);
168
124
  }
169
125
 
170
126
  float symmetric_dis(idx_t i, idx_t j) override {
171
- return fvec_inner_product (b + j * d, b + i * d, d);
127
+ return fvec_inner_product(b + j * d, b + i * d, d);
172
128
  }
173
129
 
174
- explicit FlatIPDis(const IndexFlat& storage, const float *q = nullptr)
175
- : d(storage.d),
176
- nb(storage.ntotal),
177
- q(q),
178
- b(storage.xb.data()),
179
- ndis(0) {}
130
+ explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
131
+ : d(storage.d),
132
+ nb(storage.ntotal),
133
+ q(q),
134
+ b(storage.get_xb()),
135
+ ndis(0) {}
180
136
 
181
- void set_query(const float *x) override {
137
+ void set_query(const float* x) override {
182
138
  q = x;
183
139
  }
184
140
  };
185
141
 
142
+ } // namespace
186
143
 
187
-
188
-
189
- } // namespace
190
-
191
-
192
- DistanceComputer * IndexFlat::get_distance_computer() const {
144
+ DistanceComputer* IndexFlat::get_distance_computer() const {
193
145
  if (metric_type == METRIC_L2) {
194
146
  return new FlatL2Dis(*this);
195
147
  } else if (metric_type == METRIC_INNER_PRODUCT) {
196
148
  return new FlatIPDis(*this);
197
149
  } else {
198
- return get_extra_distance_computer (d, metric_type, metric_arg,
199
- ntotal, xb.data());
150
+ return get_extra_distance_computer(
151
+ d, metric_type, metric_arg, ntotal, get_xb());
200
152
  }
201
153
  }
202
154
 
203
-
204
- void IndexFlat::reconstruct (idx_t key, float * recons) const
205
- {
206
- memcpy (recons, &(xb[key * d]), sizeof(*recons) * d);
155
+ void IndexFlat::reconstruct(idx_t key, float* recons) const {
156
+ memcpy(recons, &(codes[key * code_size]), code_size);
207
157
  }
208
158
 
209
-
210
- /* The standalone codec interface */
211
- size_t IndexFlat::sa_code_size () const
212
- {
213
- return sizeof(float) * d;
214
- }
215
-
216
- void IndexFlat::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
217
- {
218
- memcpy (bytes, x, sizeof(float) * d * n);
159
+ void IndexFlat::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
160
+ if (n > 0) {
161
+ memcpy(bytes, x, sizeof(float) * d * n);
162
+ }
219
163
  }
220
164
 
221
- void IndexFlat::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
222
- {
223
- memcpy (x, bytes, sizeof(float) * d * n);
165
+ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
166
+ if (n > 0) {
167
+ memcpy(x, bytes, sizeof(float) * d * n);
168
+ }
224
169
  }
225
170
 
226
-
227
-
228
-
229
-
230
-
231
171
  /***************************************************
232
172
  * IndexFlat1D
233
173
  ***************************************************/
234
174
 
235
-
236
- IndexFlat1D::IndexFlat1D (bool continuous_update):
237
- IndexFlatL2 (1),
238
- continuous_update (continuous_update)
239
- {
240
- }
175
+ IndexFlat1D::IndexFlat1D(bool continuous_update)
176
+ : IndexFlatL2(1), continuous_update(continuous_update) {}
241
177
 
242
178
  /// if not continuous_update, call this between the last add and
243
179
  /// the first search
244
- void IndexFlat1D::update_permutation ()
245
- {
246
- perm.resize (ntotal);
180
+ void IndexFlat1D::update_permutation() {
181
+ perm.resize(ntotal);
247
182
  if (ntotal < 1000000) {
248
- fvec_argsort (ntotal, xb.data(), (size_t*)perm.data());
183
+ fvec_argsort(ntotal, get_xb(), (size_t*)perm.data());
249
184
  } else {
250
- fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data());
185
+ fvec_argsort_parallel(ntotal, get_xb(), (size_t*)perm.data());
251
186
  }
252
187
  }
253
188
 
254
- void IndexFlat1D::add (idx_t n, const float *x)
255
- {
256
- IndexFlatL2::add (n, x);
189
+ void IndexFlat1D::add(idx_t n, const float* x) {
190
+ IndexFlatL2::add(n, x);
257
191
  if (continuous_update)
258
192
  update_permutation();
259
193
  }
260
194
 
261
- void IndexFlat1D::reset()
262
- {
195
+ void IndexFlat1D::reset() {
263
196
  IndexFlatL2::reset();
264
197
  perm.clear();
265
198
  }
266
199
 
267
- void IndexFlat1D::search (
268
- idx_t n,
269
- const float *x,
270
- idx_t k,
271
- float *distances,
272
- idx_t *labels) const
273
- {
274
- FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal,
275
- "Call update_permutation before search");
200
+ void IndexFlat1D::search(
201
+ idx_t n,
202
+ const float* x,
203
+ idx_t k,
204
+ float* distances,
205
+ idx_t* labels) const {
206
+ FAISS_THROW_IF_NOT(k > 0);
207
+
208
+ FAISS_THROW_IF_NOT_MSG(
209
+ perm.size() == ntotal, "Call update_permutation before search");
210
+ const float* xb = get_xb();
276
211
 
277
212
  #pragma omp parallel for
278
213
  for (idx_t i = 0; i < n; i++) {
279
-
280
214
  float q = x[i]; // query
281
- float *D = distances + i * k;
282
- idx_t *I = labels + i * k;
215
+ float* D = distances + i * k;
216
+ idx_t* I = labels + i * k;
283
217
 
284
218
  // binary search
285
219
  idx_t i0 = 0, i1 = ntotal;
@@ -297,8 +231,10 @@ void IndexFlat1D::search (
297
231
 
298
232
  while (i0 + 1 < i1) {
299
233
  idx_t imed = (i0 + i1) / 2;
300
- if (xb[perm[imed]] <= q) i0 = imed;
301
- else i1 = imed;
234
+ if (xb[perm[imed]] <= q)
235
+ i0 = imed;
236
+ else
237
+ i1 = imed;
302
238
  }
303
239
 
304
240
  // query is between xb[perm[i0]] and xb[perm[i1]]
@@ -311,13 +247,19 @@ void IndexFlat1D::search (
311
247
  if (q - xleft < xright - q) {
312
248
  D[wp] = q - xleft;
313
249
  I[wp] = perm[i0];
314
- i0--; wp++;
315
- if (i0 < 0) { goto finish_right; }
250
+ i0--;
251
+ wp++;
252
+ if (i0 < 0) {
253
+ goto finish_right;
254
+ }
316
255
  } else {
317
256
  D[wp] = xright - q;
318
257
  I[wp] = perm[i1];
319
- i1++; wp++;
320
- if (i1 >= ntotal) { goto finish_left; }
258
+ i1++;
259
+ wp++;
260
+ if (i1 >= ntotal) {
261
+ goto finish_left;
262
+ }
321
263
  }
322
264
  }
323
265
  goto done;
@@ -350,11 +292,8 @@ void IndexFlat1D::search (
350
292
  }
351
293
  wp++;
352
294
  }
353
- done: ;
295
+ done:;
354
296
  }
355
-
356
297
  }
357
298
 
358
-
359
-
360
299
  } // namespace faiss
@@ -12,35 +12,26 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
16
-
15
+ #include <faiss/IndexFlatCodes.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
19
  /** Index that stores the full vectors and performs exhaustive search */
21
- struct IndexFlat: Index {
22
-
23
- /// database vectors, size ntotal * d
24
- std::vector<float> xb;
25
-
26
- explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2);
27
-
28
- void add(idx_t n, const float* x) override;
29
-
30
- void reset() override;
20
+ struct IndexFlat : IndexFlatCodes {
21
+ explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
31
22
 
32
23
  void search(
33
- idx_t n,
34
- const float* x,
35
- idx_t k,
36
- float* distances,
37
- idx_t* labels) const override;
24
+ idx_t n,
25
+ const float* x,
26
+ idx_t k,
27
+ float* distances,
28
+ idx_t* labels) const override;
38
29
 
39
30
  void range_search(
40
- idx_t n,
41
- const float* x,
42
- float radius,
43
- RangeSearchResult* result) const override;
31
+ idx_t n,
32
+ const float* x,
33
+ float radius,
34
+ RangeSearchResult* result) const override;
44
35
 
45
36
  void reconstruct(idx_t key, float* recons) const override;
46
37
 
@@ -52,59 +43,52 @@ struct IndexFlat: Index {
52
43
  * @param distances
53
44
  * corresponding output distances, size n * k
54
45
  */
55
- void compute_distance_subset (
46
+ void compute_distance_subset(
56
47
  idx_t n,
57
- const float *x,
48
+ const float* x,
58
49
  idx_t k,
59
- float *distances,
60
- const idx_t *labels) const;
50
+ float* distances,
51
+ const idx_t* labels) const;
61
52
 
62
- /** remove some ids. NB that Because of the structure of the
63
- * indexing structure, the semantics of this operation are
64
- * different from the usual ones: the new ids are shifted */
65
- size_t remove_ids(const IDSelector& sel) override;
53
+ // get pointer to the floating point data
54
+ float* get_xb() {
55
+ return (float*)codes.data();
56
+ }
57
+ const float* get_xb() const {
58
+ return (const float*)codes.data();
59
+ }
66
60
 
67
- IndexFlat () {}
61
+ IndexFlat() {}
68
62
 
69
- DistanceComputer * get_distance_computer() const override;
63
+ DistanceComputer* get_distance_computer() const override;
70
64
 
71
65
  /* The stanadlone codec interface (just memcopies in this case) */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
76
-
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
66
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
79
67
 
68
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
80
69
  };
81
70
 
82
-
83
-
84
- struct IndexFlatIP:IndexFlat {
85
- explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {}
86
- IndexFlatIP () {}
71
+ struct IndexFlatIP : IndexFlat {
72
+ explicit IndexFlatIP(idx_t d) : IndexFlat(d, METRIC_INNER_PRODUCT) {}
73
+ IndexFlatIP() {}
87
74
  };
88
75
 
89
-
90
- struct IndexFlatL2:IndexFlat {
91
- explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {}
92
- IndexFlatL2 () {}
76
+ struct IndexFlatL2 : IndexFlat {
77
+ explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
78
+ IndexFlatL2() {}
93
79
  };
94
80
 
95
-
96
-
97
81
  /// optimized version for 1D "vectors".
98
- struct IndexFlat1D:IndexFlatL2 {
82
+ struct IndexFlat1D : IndexFlatL2 {
99
83
  bool continuous_update; ///< is the permutation updated continuously?
100
84
 
101
85
  std::vector<idx_t> perm; ///< sorted database indices
102
86
 
103
- explicit IndexFlat1D (bool continuous_update=true);
87
+ explicit IndexFlat1D(bool continuous_update = true);
104
88
 
105
89
  /// if not continuous_update, call this between the last add and
106
90
  /// the first search
107
- void update_permutation ();
91
+ void update_permutation();
108
92
 
109
93
  void add(idx_t n, const float* x) override;
110
94
 
@@ -112,14 +96,13 @@ struct IndexFlat1D:IndexFlatL2 {
112
96
 
113
97
  /// Warn: the distances returned are L1 not L2
114
98
  void search(
115
- idx_t n,
116
- const float* x,
117
- idx_t k,
118
- float* distances,
119
- idx_t* labels) const override;
99
+ idx_t n,
100
+ const float* x,
101
+ idx_t k,
102
+ float* distances,
103
+ idx_t* labels) const override;
120
104
  };
121
105
 
122
-
123
- }
106
+ } // namespace faiss
124
107
 
125
108
  #endif
@@ -0,0 +1,67 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexFlatCodes.h>
9
+
10
+ #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/FaissAssert.h>
12
+
13
+ namespace faiss {
14
+
15
+ IndexFlatCodes::IndexFlatCodes(size_t code_size, idx_t d, MetricType metric)
16
+ : Index(d, metric), code_size(code_size) {}
17
+
18
+ IndexFlatCodes::IndexFlatCodes() : code_size(0) {}
19
+
20
+ void IndexFlatCodes::add(idx_t n, const float* x) {
21
+ FAISS_THROW_IF_NOT(is_trained);
22
+ codes.resize((ntotal + n) * code_size);
23
+ sa_encode(n, x, &codes[ntotal * code_size]);
24
+ ntotal += n;
25
+ }
26
+
27
+ void IndexFlatCodes::reset() {
28
+ codes.clear();
29
+ ntotal = 0;
30
+ }
31
+
32
+ size_t IndexFlatCodes::sa_code_size() const {
33
+ return code_size;
34
+ }
35
+
36
+ size_t IndexFlatCodes::remove_ids(const IDSelector& sel) {
37
+ idx_t j = 0;
38
+ for (idx_t i = 0; i < ntotal; i++) {
39
+ if (sel.is_member(i)) {
40
+ // should be removed
41
+ } else {
42
+ if (i > j) {
43
+ memmove(&codes[code_size * j],
44
+ &codes[code_size * i],
45
+ code_size);
46
+ }
47
+ j++;
48
+ }
49
+ }
50
+ size_t nremove = ntotal - j;
51
+ if (nremove > 0) {
52
+ ntotal = j;
53
+ codes.resize(ntotal * code_size);
54
+ }
55
+ return nremove;
56
+ }
57
+
58
+ void IndexFlatCodes::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
59
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
60
+ sa_decode(ni, codes.data() + i0 * code_size, recons);
61
+ }
62
+
63
+ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
64
+ reconstruct_n(key, 1, recons);
65
+ }
66
+
67
+ } // namespace faiss
@@ -0,0 +1,47 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <faiss/Index.h>
13
+ #include <vector>
14
+
15
+ namespace faiss {
16
+
17
+ /** Index that encodes all vectors as fixed-size codes (size code_size). Storage
18
+ * is in the codes vector */
19
+ struct IndexFlatCodes : Index {
20
+ size_t code_size;
21
+
22
+ /// encoded dataset, size ntotal * code_size
23
+ std::vector<uint8_t> codes;
24
+
25
+ IndexFlatCodes();
26
+
27
+ IndexFlatCodes(size_t code_size, idx_t d, MetricType metric = METRIC_L2);
28
+
29
+ /// default add uses sa_encode
30
+ void add(idx_t n, const float* x) override;
31
+
32
+ void reset() override;
33
+
34
+ /// reconstruction using the codec interface
35
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
36
+
37
+ void reconstruct(idx_t key, float* recons) const override;
38
+
39
+ size_t sa_code_size() const override;
40
+
41
+ /** remove some ids. NB that Because of the structure of the
42
+ * indexing structure, the semantics of this operation are
43
+ * different from the usual ones: the new ids are shifted */
44
+ size_t remove_ids(const IDSelector& sel) override;
45
+ };
46
+
47
+ } // namespace faiss