faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,132 +9,88 @@
9
9
 
10
10
  #include <faiss/IndexFlat.h>
11
11
 
12
- #include <cstring>
12
+ #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/FaissAssert.h>
14
+ #include <faiss/utils/Heap.h>
13
15
  #include <faiss/utils/distances.h>
14
16
  #include <faiss/utils/extra_distances.h>
15
17
  #include <faiss/utils/utils.h>
16
- #include <faiss/utils/Heap.h>
17
- #include <faiss/impl/FaissAssert.h>
18
- #include <faiss/impl/AuxIndexStructures.h>
19
-
18
+ #include <cstring>
20
19
 
21
20
  namespace faiss {
22
21
 
23
- IndexFlat::IndexFlat (idx_t d, MetricType metric):
24
- Index(d, metric)
25
- {
26
- }
27
-
22
+ IndexFlat::IndexFlat(idx_t d, MetricType metric)
23
+ : IndexFlatCodes(sizeof(float) * d, d, metric) {}
28
24
 
25
+ void IndexFlat::search(
26
+ idx_t n,
27
+ const float* x,
28
+ idx_t k,
29
+ float* distances,
30
+ idx_t* labels) const {
31
+ FAISS_THROW_IF_NOT(k > 0);
29
32
 
30
- void IndexFlat::add (idx_t n, const float *x) {
31
- xb.insert(xb.end(), x, x + n * d);
32
- ntotal += n;
33
- }
34
-
35
-
36
- void IndexFlat::reset() {
37
- xb.clear();
38
- ntotal = 0;
39
- }
40
-
41
-
42
- void IndexFlat::search (idx_t n, const float *x, idx_t k,
43
- float *distances, idx_t *labels) const
44
- {
45
33
  // we see the distances and labels as heaps
46
34
 
47
35
  if (metric_type == METRIC_INNER_PRODUCT) {
48
- float_minheap_array_t res = {
49
- size_t(n), size_t(k), labels, distances};
50
- knn_inner_product (x, xb.data(), d, n, ntotal, &res);
36
+ float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
37
+ knn_inner_product(x, get_xb(), d, n, ntotal, &res);
51
38
  } else if (metric_type == METRIC_L2) {
52
- float_maxheap_array_t res = {
53
- size_t(n), size_t(k), labels, distances};
54
- knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
39
+ float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
40
+ knn_L2sqr(x, get_xb(), d, n, ntotal, &res);
55
41
  } else {
56
- float_maxheap_array_t res = {
57
- size_t(n), size_t(k), labels, distances};
58
- knn_extra_metrics (x, xb.data(), d, n, ntotal,
59
- metric_type, metric_arg,
60
- &res);
42
+ float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
43
+ knn_extra_metrics(
44
+ x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
61
45
  }
62
46
  }
63
47
 
64
- void IndexFlat::range_search (idx_t n, const float *x, float radius,
65
- RangeSearchResult *result) const
66
- {
48
+ void IndexFlat::range_search(
49
+ idx_t n,
50
+ const float* x,
51
+ float radius,
52
+ RangeSearchResult* result) const {
67
53
  switch (metric_type) {
68
- case METRIC_INNER_PRODUCT:
69
- range_search_inner_product (x, xb.data(), d, n, ntotal,
70
- radius, result);
71
- break;
72
- case METRIC_L2:
73
- range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
74
- break;
75
- default:
76
- FAISS_THROW_MSG("metric type not supported");
54
+ case METRIC_INNER_PRODUCT:
55
+ range_search_inner_product(
56
+ x, get_xb(), d, n, ntotal, radius, result);
57
+ break;
58
+ case METRIC_L2:
59
+ range_search_L2sqr(x, get_xb(), d, n, ntotal, radius, result);
60
+ break;
61
+ default:
62
+ FAISS_THROW_MSG("metric type not supported");
77
63
  }
78
64
  }
79
65
 
80
-
81
- void IndexFlat::compute_distance_subset (
82
- idx_t n,
83
- const float *x,
84
- idx_t k,
85
- float *distances,
86
- const idx_t *labels) const
87
- {
66
+ void IndexFlat::compute_distance_subset(
67
+ idx_t n,
68
+ const float* x,
69
+ idx_t k,
70
+ float* distances,
71
+ const idx_t* labels) const {
88
72
  switch (metric_type) {
89
73
  case METRIC_INNER_PRODUCT:
90
- fvec_inner_products_by_idx (
91
- distances,
92
- x, xb.data(), labels, d, n, k);
74
+ fvec_inner_products_by_idx(distances, x, get_xb(), labels, d, n, k);
93
75
  break;
94
76
  case METRIC_L2:
95
- fvec_L2sqr_by_idx (
96
- distances,
97
- x, xb.data(), labels, d, n, k);
77
+ fvec_L2sqr_by_idx(distances, x, get_xb(), labels, d, n, k);
98
78
  break;
99
79
  default:
100
80
  FAISS_THROW_MSG("metric type not supported");
101
81
  }
102
-
103
82
  }
104
83
 
105
- size_t IndexFlat::remove_ids (const IDSelector & sel)
106
- {
107
- idx_t j = 0;
108
- for (idx_t i = 0; i < ntotal; i++) {
109
- if (sel.is_member (i)) {
110
- // should be removed
111
- } else {
112
- if (i > j) {
113
- memmove (&xb[d * j], &xb[d * i], sizeof(xb[0]) * d);
114
- }
115
- j++;
116
- }
117
- }
118
- size_t nremove = ntotal - j;
119
- if (nremove > 0) {
120
- ntotal = j;
121
- xb.resize (ntotal * d);
122
- }
123
- return nremove;
124
- }
125
-
126
-
127
84
  namespace {
128
85
 
129
-
130
86
  struct FlatL2Dis : DistanceComputer {
131
87
  size_t d;
132
88
  Index::idx_t nb;
133
- const float *q;
134
- const float *b;
89
+ const float* q;
90
+ const float* b;
135
91
  size_t ndis;
136
92
 
137
- float operator () (idx_t i) override {
93
+ float operator()(idx_t i) override {
138
94
  ndis++;
139
95
  return fvec_L2sqr(q, b + i * d, d);
140
96
  }
@@ -143,14 +99,14 @@ struct FlatL2Dis : DistanceComputer {
143
99
  return fvec_L2sqr(b + j * d, b + i * d, d);
144
100
  }
145
101
 
146
- explicit FlatL2Dis(const IndexFlat& storage, const float *q = nullptr)
147
- : d(storage.d),
148
- nb(storage.ntotal),
149
- q(q),
150
- b(storage.xb.data()),
151
- ndis(0) {}
102
+ explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
103
+ : d(storage.d),
104
+ nb(storage.ntotal),
105
+ q(q),
106
+ b(storage.get_xb()),
107
+ ndis(0) {}
152
108
 
153
- void set_query(const float *x) override {
109
+ void set_query(const float* x) override {
154
110
  q = x;
155
111
  }
156
112
  };
@@ -158,128 +114,106 @@ struct FlatL2Dis : DistanceComputer {
158
114
  struct FlatIPDis : DistanceComputer {
159
115
  size_t d;
160
116
  Index::idx_t nb;
161
- const float *q;
162
- const float *b;
117
+ const float* q;
118
+ const float* b;
163
119
  size_t ndis;
164
120
 
165
- float operator () (idx_t i) override {
121
+ float operator()(idx_t i) override {
166
122
  ndis++;
167
- return fvec_inner_product (q, b + i * d, d);
123
+ return fvec_inner_product(q, b + i * d, d);
168
124
  }
169
125
 
170
126
  float symmetric_dis(idx_t i, idx_t j) override {
171
- return fvec_inner_product (b + j * d, b + i * d, d);
127
+ return fvec_inner_product(b + j * d, b + i * d, d);
172
128
  }
173
129
 
174
- explicit FlatIPDis(const IndexFlat& storage, const float *q = nullptr)
175
- : d(storage.d),
176
- nb(storage.ntotal),
177
- q(q),
178
- b(storage.xb.data()),
179
- ndis(0) {}
130
+ explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
131
+ : d(storage.d),
132
+ nb(storage.ntotal),
133
+ q(q),
134
+ b(storage.get_xb()),
135
+ ndis(0) {}
180
136
 
181
- void set_query(const float *x) override {
137
+ void set_query(const float* x) override {
182
138
  q = x;
183
139
  }
184
140
  };
185
141
 
142
+ } // namespace
186
143
 
187
-
188
-
189
- } // namespace
190
-
191
-
192
- DistanceComputer * IndexFlat::get_distance_computer() const {
144
+ DistanceComputer* IndexFlat::get_distance_computer() const {
193
145
  if (metric_type == METRIC_L2) {
194
146
  return new FlatL2Dis(*this);
195
147
  } else if (metric_type == METRIC_INNER_PRODUCT) {
196
148
  return new FlatIPDis(*this);
197
149
  } else {
198
- return get_extra_distance_computer (d, metric_type, metric_arg,
199
- ntotal, xb.data());
150
+ return get_extra_distance_computer(
151
+ d, metric_type, metric_arg, ntotal, get_xb());
200
152
  }
201
153
  }
202
154
 
203
-
204
- void IndexFlat::reconstruct (idx_t key, float * recons) const
205
- {
206
- memcpy (recons, &(xb[key * d]), sizeof(*recons) * d);
155
+ void IndexFlat::reconstruct(idx_t key, float* recons) const {
156
+ memcpy(recons, &(codes[key * code_size]), code_size);
207
157
  }
208
158
 
209
-
210
- /* The standalone codec interface */
211
- size_t IndexFlat::sa_code_size () const
212
- {
213
- return sizeof(float) * d;
214
- }
215
-
216
- void IndexFlat::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
217
- {
218
- memcpy (bytes, x, sizeof(float) * d * n);
159
+ void IndexFlat::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
160
+ if (n > 0) {
161
+ memcpy(bytes, x, sizeof(float) * d * n);
162
+ }
219
163
  }
220
164
 
221
- void IndexFlat::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
222
- {
223
- memcpy (x, bytes, sizeof(float) * d * n);
165
+ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
166
+ if (n > 0) {
167
+ memcpy(x, bytes, sizeof(float) * d * n);
168
+ }
224
169
  }
225
170
 
226
-
227
-
228
-
229
-
230
-
231
171
  /***************************************************
232
172
  * IndexFlat1D
233
173
  ***************************************************/
234
174
 
235
-
236
- IndexFlat1D::IndexFlat1D (bool continuous_update):
237
- IndexFlatL2 (1),
238
- continuous_update (continuous_update)
239
- {
240
- }
175
+ IndexFlat1D::IndexFlat1D(bool continuous_update)
176
+ : IndexFlatL2(1), continuous_update(continuous_update) {}
241
177
 
242
178
  /// if not continuous_update, call this between the last add and
243
179
  /// the first search
244
- void IndexFlat1D::update_permutation ()
245
- {
246
- perm.resize (ntotal);
180
+ void IndexFlat1D::update_permutation() {
181
+ perm.resize(ntotal);
247
182
  if (ntotal < 1000000) {
248
- fvec_argsort (ntotal, xb.data(), (size_t*)perm.data());
183
+ fvec_argsort(ntotal, get_xb(), (size_t*)perm.data());
249
184
  } else {
250
- fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data());
185
+ fvec_argsort_parallel(ntotal, get_xb(), (size_t*)perm.data());
251
186
  }
252
187
  }
253
188
 
254
- void IndexFlat1D::add (idx_t n, const float *x)
255
- {
256
- IndexFlatL2::add (n, x);
189
+ void IndexFlat1D::add(idx_t n, const float* x) {
190
+ IndexFlatL2::add(n, x);
257
191
  if (continuous_update)
258
192
  update_permutation();
259
193
  }
260
194
 
261
- void IndexFlat1D::reset()
262
- {
195
+ void IndexFlat1D::reset() {
263
196
  IndexFlatL2::reset();
264
197
  perm.clear();
265
198
  }
266
199
 
267
- void IndexFlat1D::search (
268
- idx_t n,
269
- const float *x,
270
- idx_t k,
271
- float *distances,
272
- idx_t *labels) const
273
- {
274
- FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal,
275
- "Call update_permutation before search");
200
+ void IndexFlat1D::search(
201
+ idx_t n,
202
+ const float* x,
203
+ idx_t k,
204
+ float* distances,
205
+ idx_t* labels) const {
206
+ FAISS_THROW_IF_NOT(k > 0);
207
+
208
+ FAISS_THROW_IF_NOT_MSG(
209
+ perm.size() == ntotal, "Call update_permutation before search");
210
+ const float* xb = get_xb();
276
211
 
277
212
  #pragma omp parallel for
278
213
  for (idx_t i = 0; i < n; i++) {
279
-
280
214
  float q = x[i]; // query
281
- float *D = distances + i * k;
282
- idx_t *I = labels + i * k;
215
+ float* D = distances + i * k;
216
+ idx_t* I = labels + i * k;
283
217
 
284
218
  // binary search
285
219
  idx_t i0 = 0, i1 = ntotal;
@@ -297,8 +231,10 @@ void IndexFlat1D::search (
297
231
 
298
232
  while (i0 + 1 < i1) {
299
233
  idx_t imed = (i0 + i1) / 2;
300
- if (xb[perm[imed]] <= q) i0 = imed;
301
- else i1 = imed;
234
+ if (xb[perm[imed]] <= q)
235
+ i0 = imed;
236
+ else
237
+ i1 = imed;
302
238
  }
303
239
 
304
240
  // query is between xb[perm[i0]] and xb[perm[i1]]
@@ -311,13 +247,19 @@ void IndexFlat1D::search (
311
247
  if (q - xleft < xright - q) {
312
248
  D[wp] = q - xleft;
313
249
  I[wp] = perm[i0];
314
- i0--; wp++;
315
- if (i0 < 0) { goto finish_right; }
250
+ i0--;
251
+ wp++;
252
+ if (i0 < 0) {
253
+ goto finish_right;
254
+ }
316
255
  } else {
317
256
  D[wp] = xright - q;
318
257
  I[wp] = perm[i1];
319
- i1++; wp++;
320
- if (i1 >= ntotal) { goto finish_left; }
258
+ i1++;
259
+ wp++;
260
+ if (i1 >= ntotal) {
261
+ goto finish_left;
262
+ }
321
263
  }
322
264
  }
323
265
  goto done;
@@ -350,11 +292,8 @@ void IndexFlat1D::search (
350
292
  }
351
293
  wp++;
352
294
  }
353
- done: ;
295
+ done:;
354
296
  }
355
-
356
297
  }
357
298
 
358
-
359
-
360
299
  } // namespace faiss
@@ -12,35 +12,26 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
16
-
15
+ #include <faiss/IndexFlatCodes.h>
17
16
 
18
17
  namespace faiss {
19
18
 
20
19
  /** Index that stores the full vectors and performs exhaustive search */
21
- struct IndexFlat: Index {
22
-
23
- /// database vectors, size ntotal * d
24
- std::vector<float> xb;
25
-
26
- explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2);
27
-
28
- void add(idx_t n, const float* x) override;
29
-
30
- void reset() override;
20
+ struct IndexFlat : IndexFlatCodes {
21
+ explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
31
22
 
32
23
  void search(
33
- idx_t n,
34
- const float* x,
35
- idx_t k,
36
- float* distances,
37
- idx_t* labels) const override;
24
+ idx_t n,
25
+ const float* x,
26
+ idx_t k,
27
+ float* distances,
28
+ idx_t* labels) const override;
38
29
 
39
30
  void range_search(
40
- idx_t n,
41
- const float* x,
42
- float radius,
43
- RangeSearchResult* result) const override;
31
+ idx_t n,
32
+ const float* x,
33
+ float radius,
34
+ RangeSearchResult* result) const override;
44
35
 
45
36
  void reconstruct(idx_t key, float* recons) const override;
46
37
 
@@ -52,59 +43,52 @@ struct IndexFlat: Index {
52
43
  * @param distances
53
44
  * corresponding output distances, size n * k
54
45
  */
55
- void compute_distance_subset (
46
+ void compute_distance_subset(
56
47
  idx_t n,
57
- const float *x,
48
+ const float* x,
58
49
  idx_t k,
59
- float *distances,
60
- const idx_t *labels) const;
50
+ float* distances,
51
+ const idx_t* labels) const;
61
52
 
62
- /** remove some ids. NB that Because of the structure of the
63
- * indexing structure, the semantics of this operation are
64
- * different from the usual ones: the new ids are shifted */
65
- size_t remove_ids(const IDSelector& sel) override;
53
+ // get pointer to the floating point data
54
+ float* get_xb() {
55
+ return (float*)codes.data();
56
+ }
57
+ const float* get_xb() const {
58
+ return (const float*)codes.data();
59
+ }
66
60
 
67
- IndexFlat () {}
61
+ IndexFlat() {}
68
62
 
69
- DistanceComputer * get_distance_computer() const override;
63
+ DistanceComputer* get_distance_computer() const override;
70
64
 
71
65
  /* The stanadlone codec interface (just memcopies in this case) */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
76
-
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
66
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
79
67
 
68
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
80
69
  };
81
70
 
82
-
83
-
84
- struct IndexFlatIP:IndexFlat {
85
- explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {}
86
- IndexFlatIP () {}
71
+ struct IndexFlatIP : IndexFlat {
72
+ explicit IndexFlatIP(idx_t d) : IndexFlat(d, METRIC_INNER_PRODUCT) {}
73
+ IndexFlatIP() {}
87
74
  };
88
75
 
89
-
90
- struct IndexFlatL2:IndexFlat {
91
- explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {}
92
- IndexFlatL2 () {}
76
+ struct IndexFlatL2 : IndexFlat {
77
+ explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
78
+ IndexFlatL2() {}
93
79
  };
94
80
 
95
-
96
-
97
81
  /// optimized version for 1D "vectors".
98
- struct IndexFlat1D:IndexFlatL2 {
82
+ struct IndexFlat1D : IndexFlatL2 {
99
83
  bool continuous_update; ///< is the permutation updated continuously?
100
84
 
101
85
  std::vector<idx_t> perm; ///< sorted database indices
102
86
 
103
- explicit IndexFlat1D (bool continuous_update=true);
87
+ explicit IndexFlat1D(bool continuous_update = true);
104
88
 
105
89
  /// if not continuous_update, call this between the last add and
106
90
  /// the first search
107
- void update_permutation ();
91
+ void update_permutation();
108
92
 
109
93
  void add(idx_t n, const float* x) override;
110
94
 
@@ -112,14 +96,13 @@ struct IndexFlat1D:IndexFlatL2 {
112
96
 
113
97
  /// Warn: the distances returned are L1 not L2
114
98
  void search(
115
- idx_t n,
116
- const float* x,
117
- idx_t k,
118
- float* distances,
119
- idx_t* labels) const override;
99
+ idx_t n,
100
+ const float* x,
101
+ idx_t k,
102
+ float* distances,
103
+ idx_t* labels) const override;
120
104
  };
121
105
 
122
-
123
- }
106
+ } // namespace faiss
124
107
 
125
108
  #endif
@@ -0,0 +1,67 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexFlatCodes.h>
9
+
10
+ #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/FaissAssert.h>
12
+
13
+ namespace faiss {
14
+
15
+ IndexFlatCodes::IndexFlatCodes(size_t code_size, idx_t d, MetricType metric)
16
+ : Index(d, metric), code_size(code_size) {}
17
+
18
+ IndexFlatCodes::IndexFlatCodes() : code_size(0) {}
19
+
20
+ void IndexFlatCodes::add(idx_t n, const float* x) {
21
+ FAISS_THROW_IF_NOT(is_trained);
22
+ codes.resize((ntotal + n) * code_size);
23
+ sa_encode(n, x, &codes[ntotal * code_size]);
24
+ ntotal += n;
25
+ }
26
+
27
+ void IndexFlatCodes::reset() {
28
+ codes.clear();
29
+ ntotal = 0;
30
+ }
31
+
32
+ size_t IndexFlatCodes::sa_code_size() const {
33
+ return code_size;
34
+ }
35
+
36
+ size_t IndexFlatCodes::remove_ids(const IDSelector& sel) {
37
+ idx_t j = 0;
38
+ for (idx_t i = 0; i < ntotal; i++) {
39
+ if (sel.is_member(i)) {
40
+ // should be removed
41
+ } else {
42
+ if (i > j) {
43
+ memmove(&codes[code_size * j],
44
+ &codes[code_size * i],
45
+ code_size);
46
+ }
47
+ j++;
48
+ }
49
+ }
50
+ size_t nremove = ntotal - j;
51
+ if (nremove > 0) {
52
+ ntotal = j;
53
+ codes.resize(ntotal * code_size);
54
+ }
55
+ return nremove;
56
+ }
57
+
58
+ void IndexFlatCodes::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
59
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
60
+ sa_decode(ni, codes.data() + i0 * code_size, recons);
61
+ }
62
+
63
+ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
64
+ reconstruct_n(key, 1, recons);
65
+ }
66
+
67
+ } // namespace faiss
@@ -0,0 +1,47 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <faiss/Index.h>
13
+ #include <vector>
14
+
15
+ namespace faiss {
16
+
17
+ /** Index that encodes all vectors as fixed-size codes (size code_size). Storage
18
+ * is in the codes vector */
19
+ struct IndexFlatCodes : Index {
20
+ size_t code_size;
21
+
22
+ /// encoded dataset, size ntotal * code_size
23
+ std::vector<uint8_t> codes;
24
+
25
+ IndexFlatCodes();
26
+
27
+ IndexFlatCodes(size_t code_size, idx_t d, MetricType metric = METRIC_L2);
28
+
29
+ /// default add uses sa_encode
30
+ void add(idx_t n, const float* x) override;
31
+
32
+ void reset() override;
33
+
34
+ /// reconstruction using the codec interface
35
+ void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
36
+
37
+ void reconstruct(idx_t key, float* recons) const override;
38
+
39
+ size_t sa_code_size() const override;
40
+
41
+ /** remove some ids. NB that Because of the structure of the
42
+ * indexing structure, the semantics of this operation are
43
+ * different from the usual ones: the new ids are shifted */
44
+ size_t remove_ids(const IDSelector& sel) override;
45
+ };
46
+
47
+ } // namespace faiss