faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -15,6 +15,8 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
+ struct IDSelector;
19
+
18
20
  // When offsets list id + offset are encoded in an uint64
19
21
  // we call this LO = list-offset
20
22
 
@@ -34,8 +36,6 @@ inline uint64_t lo_offset(uint64_t lo) {
34
36
  * Direct map: a way to map back from ids to inverted lists
35
37
  */
36
38
  struct DirectMap {
37
- typedef Index::idx_t idx_t;
38
-
39
39
  enum Type {
40
40
  NoMap = 0, // default
41
41
  Array = 1, // sequential ids (only for add, no add_with_ids)
@@ -91,8 +91,6 @@ struct DirectMap {
91
91
 
92
92
  /// Thread-safe way of updating the direct_map
93
93
  struct DirectMapAdd {
94
- typedef Index::idx_t idx_t;
95
-
96
94
  using Type = DirectMap::Type;
97
95
 
98
96
  DirectMap& direct_map;
@@ -10,23 +10,32 @@
10
10
  #include <faiss/invlists/InvertedLists.h>
11
11
 
12
12
  #include <cstdio>
13
+ #include <memory>
13
14
 
14
15
  #include <faiss/impl/FaissAssert.h>
15
16
  #include <faiss/utils/utils.h>
16
17
 
17
18
  namespace faiss {
18
19
 
20
+ InvertedListsIterator::~InvertedListsIterator() {}
21
+
19
22
  /*****************************************
20
23
  * InvertedLists implementation
21
24
  ******************************************/
22
25
 
23
26
  InvertedLists::InvertedLists(size_t nlist, size_t code_size)
24
- : nlist(nlist), code_size(code_size) {}
27
+ : nlist(nlist), code_size(code_size), use_iterator(false) {}
25
28
 
26
29
  InvertedLists::~InvertedLists() {}
27
30
 
28
- InvertedLists::idx_t InvertedLists::get_single_id(size_t list_no, size_t offset)
29
- const {
31
+ bool InvertedLists::is_empty(size_t list_no) const {
32
+ return use_iterator
33
+ ? !std::unique_ptr<InvertedListsIterator>(get_iterator(list_no))
34
+ ->is_available()
35
+ : list_size(list_no) == 0;
36
+ }
37
+
38
+ idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
30
39
  assert(offset < list_size(list_no));
31
40
  const idx_t* ids = get_ids(list_no);
32
41
  idx_t id = ids[offset];
@@ -67,6 +76,10 @@ void InvertedLists::reset() {
67
76
  }
68
77
  }
69
78
 
79
+ InvertedListsIterator* InvertedLists::get_iterator(size_t /*list_no*/) const {
80
+ FAISS_THROW_MSG("get_iterator is not supported");
81
+ }
82
+
70
83
  void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
71
84
  #pragma omp parallel for
72
85
  for (idx_t i = 0; i < nlist; i++) {
@@ -87,6 +100,98 @@ void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
87
100
  }
88
101
  }
89
102
 
103
+ size_t InvertedLists::copy_subset_to(
104
+ InvertedLists& oivf,
105
+ subset_type_t subset_type,
106
+ idx_t a1,
107
+ idx_t a2) const {
108
+ FAISS_THROW_IF_NOT(nlist == oivf.nlist);
109
+ FAISS_THROW_IF_NOT(code_size == oivf.code_size);
110
+ FAISS_THROW_IF_NOT_FMT(
111
+ subset_type >= 0 && subset_type <= 4,
112
+ "subset type %d not implemented",
113
+ subset_type);
114
+ size_t accu_n = 0;
115
+ size_t accu_a1 = 0;
116
+ size_t accu_a2 = 0;
117
+ size_t n_added = 0;
118
+
119
+ size_t ntotal = 0;
120
+ if (subset_type == 2) {
121
+ ntotal = compute_ntotal();
122
+ }
123
+
124
+ for (idx_t list_no = 0; list_no < nlist; list_no++) {
125
+ size_t n = list_size(list_no);
126
+ ScopedIds ids_in(this, list_no);
127
+
128
+ if (subset_type == SUBSET_TYPE_ID_RANGE) {
129
+ for (idx_t i = 0; i < n; i++) {
130
+ idx_t id = ids_in[i];
131
+ if (a1 <= id && id < a2) {
132
+ oivf.add_entry(
133
+ list_no,
134
+ get_single_id(list_no, i),
135
+ ScopedCodes(this, list_no, i).get());
136
+ n_added++;
137
+ }
138
+ }
139
+ } else if (subset_type == SUBSET_TYPE_ID_MOD) {
140
+ for (idx_t i = 0; i < n; i++) {
141
+ idx_t id = ids_in[i];
142
+ if (id % a1 == a2) {
143
+ oivf.add_entry(
144
+ list_no,
145
+ get_single_id(list_no, i),
146
+ ScopedCodes(this, list_no, i).get());
147
+ n_added++;
148
+ }
149
+ }
150
+ } else if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) {
151
+ // see what is allocated to a1 and to a2
152
+ size_t next_accu_n = accu_n + n;
153
+ size_t next_accu_a1 = next_accu_n * a1 / ntotal;
154
+ size_t i1 = next_accu_a1 - accu_a1;
155
+ size_t next_accu_a2 = next_accu_n * a2 / ntotal;
156
+ size_t i2 = next_accu_a2 - accu_a2;
157
+
158
+ for (idx_t i = i1; i < i2; i++) {
159
+ oivf.add_entry(
160
+ list_no,
161
+ get_single_id(list_no, i),
162
+ ScopedCodes(this, list_no, i).get());
163
+ }
164
+
165
+ n_added += i2 - i1;
166
+ accu_a1 = next_accu_a1;
167
+ accu_a2 = next_accu_a2;
168
+ } else if (subset_type == SUBSET_TYPE_INVLIST_FRACTION) {
169
+ size_t i1 = n * a2 / a1;
170
+ size_t i2 = n * (a2 + 1) / a1;
171
+
172
+ for (idx_t i = i1; i < i2; i++) {
173
+ oivf.add_entry(
174
+ list_no,
175
+ get_single_id(list_no, i),
176
+ ScopedCodes(this, list_no, i).get());
177
+ }
178
+
179
+ n_added += i2 - i1;
180
+ } else if (subset_type == SUBSET_TYPE_INVLIST) {
181
+ if (list_no >= a1 && list_no < a2) {
182
+ oivf.add_entries(
183
+ list_no,
184
+ n,
185
+ ScopedIds(this, list_no).get(),
186
+ ScopedCodes(this, list_no).get());
187
+ n_added += n;
188
+ }
189
+ }
190
+ accu_n += n;
191
+ }
192
+ return n_added;
193
+ }
194
+
90
195
  double InvertedLists::imbalance_factor() const {
91
196
  std::vector<int> hist(nlist);
92
197
 
@@ -109,7 +214,9 @@ void InvertedLists::print_stats() const {
109
214
  }
110
215
  for (size_t i = 0; i < sizes.size(); i++) {
111
216
  if (sizes[i]) {
112
- printf("list size in < %d: %d instances\n", 1 << i, sizes[i]);
217
+ printf("list size in < %zu: %d instances\n",
218
+ static_cast<size_t>(1) << i,
219
+ sizes[i]);
113
220
  }
114
221
  }
115
222
  }
@@ -158,7 +265,7 @@ const uint8_t* ArrayInvertedLists::get_codes(size_t list_no) const {
158
265
  return codes[list_no].data();
159
266
  }
160
267
 
161
- const InvertedLists::idx_t* ArrayInvertedLists::get_ids(size_t list_no) const {
268
+ const idx_t* ArrayInvertedLists::get_ids(size_t list_no) const {
162
269
  assert(list_no < nlist);
163
270
  return ids[list_no].data();
164
271
  }
@@ -267,7 +374,7 @@ void HStackInvertedLists::release_codes(size_t, const uint8_t* codes) const {
267
374
  delete[] codes;
268
375
  }
269
376
 
270
- const Index::idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
377
+ const idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
271
378
  idx_t *ids = new idx_t[list_size(list_no)], *c = ids;
272
379
 
273
380
  for (int i = 0; i < ils.size(); i++) {
@@ -281,8 +388,7 @@ const Index::idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
281
388
  return ids;
282
389
  }
283
390
 
284
- Index::idx_t HStackInvertedLists::get_single_id(size_t list_no, size_t offset)
285
- const {
391
+ idx_t HStackInvertedLists::get_single_id(size_t list_no, size_t offset) const {
286
392
  for (int i = 0; i < ils.size(); i++) {
287
393
  const InvertedLists* il = ils[i];
288
394
  size_t sz = il->list_size(list_no);
@@ -312,8 +418,6 @@ void HStackInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist)
312
418
 
313
419
  namespace {
314
420
 
315
- using idx_t = InvertedLists::idx_t;
316
-
317
421
  idx_t translate_list_no(const SliceInvertedLists* sil, idx_t list_no) {
318
422
  FAISS_THROW_IF_NOT(list_no >= 0 && list_no < sil->nlist);
319
423
  return list_no + sil->i0;
@@ -349,12 +453,11 @@ void SliceInvertedLists::release_codes(size_t list_no, const uint8_t* codes)
349
453
  return il->release_codes(translate_list_no(this, list_no), codes);
350
454
  }
351
455
 
352
- const Index::idx_t* SliceInvertedLists::get_ids(size_t list_no) const {
456
+ const idx_t* SliceInvertedLists::get_ids(size_t list_no) const {
353
457
  return il->get_ids(translate_list_no(this, list_no));
354
458
  }
355
459
 
356
- Index::idx_t SliceInvertedLists::get_single_id(size_t list_no, size_t offset)
357
- const {
460
+ idx_t SliceInvertedLists::get_single_id(size_t list_no, size_t offset) const {
358
461
  return il->get_single_id(translate_list_no(this, list_no), offset);
359
462
  }
360
463
 
@@ -380,8 +483,6 @@ void SliceInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist)
380
483
 
381
484
  namespace {
382
485
 
383
- using idx_t = InvertedLists::idx_t;
384
-
385
486
  // find the invlist this number belongs to
386
487
  int translate_list_no(const VStackInvertedLists* vil, idx_t list_no) {
387
488
  FAISS_THROW_IF_NOT(list_no >= 0 && list_no < vil->nlist);
@@ -449,14 +550,13 @@ void VStackInvertedLists::release_codes(size_t list_no, const uint8_t* codes)
449
550
  return ils[i]->release_codes(list_no, codes);
450
551
  }
451
552
 
452
- const Index::idx_t* VStackInvertedLists::get_ids(size_t list_no) const {
553
+ const idx_t* VStackInvertedLists::get_ids(size_t list_no) const {
453
554
  int i = translate_list_no(this, list_no);
454
555
  list_no -= cumsz[i];
455
556
  return ils[i]->get_ids(list_no);
456
557
  }
457
558
 
458
- Index::idx_t VStackInvertedLists::get_single_id(size_t list_no, size_t offset)
459
- const {
559
+ idx_t VStackInvertedLists::get_single_id(size_t list_no, size_t offset) const {
460
560
  int i = translate_list_no(this, list_no);
461
561
  list_no -= cumsz[i];
462
562
  return ils[i]->get_single_id(list_no, offset);
@@ -15,11 +15,18 @@
15
15
  * the interface.
16
16
  */
17
17
 
18
- #include <faiss/Index.h>
18
+ #include <faiss/MetricType.h>
19
19
  #include <vector>
20
20
 
21
21
  namespace faiss {
22
22
 
23
/** Abstract iterator over the entries of one inverted list.
 *
 * Instances are obtained from InvertedLists::get_iterator(); the caller
 * takes ownership (InvertedLists::is_empty releases it via a
 * std::unique_ptr).
 */
struct InvertedListsIterator {
    virtual ~InvertedListsIterator();
    /// whether the current position holds an entry; on a freshly created
    /// iterator, false means the list is empty
    virtual bool is_available() const = 0;
    /// presumably advances to the next entry — confirm in implementations
    virtual void next() = 0;
    /// (id, pointer to code bytes) of the current entry; lifetime of the
    /// code pointer is implementation-defined — TODO confirm
    virtual std::pair<idx_t, const uint8_t*> get_id_and_codes() = 0;
};
29
+
23
30
  /** Table of inverted lists
24
31
  * multithreading rules:
25
32
  * - concurrent read accesses are allowed
@@ -28,13 +35,14 @@ namespace faiss {
28
35
  * are allowed
29
36
  */
30
37
  struct InvertedLists {
31
- typedef Index::idx_t idx_t;
32
-
33
38
  size_t nlist; ///< number of possible key values
34
39
  size_t code_size; ///< code size per vector in bytes
40
+ bool use_iterator;
35
41
 
36
42
  InvertedLists(size_t nlist, size_t code_size);
37
43
 
44
+ virtual ~InvertedLists();
45
+
38
46
  /// used for BlockInvertedLists, where the codes are packed into groups
39
47
  /// and the individual code size is meaningless
40
48
  static const size_t INVALID_CODE_SIZE = static_cast<size_t>(-1);
@@ -42,9 +50,15 @@ struct InvertedLists {
42
50
  /*************************
43
51
  * Read only functions */
44
52
 
53
+ // check if the list is empty
54
+ bool is_empty(size_t list_no) const;
55
+
45
56
  /// get the size of a list
46
57
  virtual size_t list_size(size_t list_no) const = 0;
47
58
 
59
+ /// get iterable for lists that use_iterator
60
+ virtual InvertedListsIterator* get_iterator(size_t list_no) const;
61
+
48
62
  /** get the codes for an inverted list
49
63
  * must be released by release_codes
50
64
  *
@@ -105,10 +119,36 @@ struct InvertedLists {
105
119
 
106
120
  virtual void reset();
107
121
 
122
+ /*************************
123
+ * high level functions */
124
+
108
125
  /// move all entries from oivf (empty on output)
109
126
  void merge_from(InvertedLists* oivf, size_t add_id);
110
127
 
111
- virtual ~InvertedLists();
128
// How copy_subset_to selects the entries to copy. Every mode is
// parameterized by two integers, a1 and a2.
enum subset_type_t : int {
    // --- selection by id ---
    /// copy the entries whose id is in [a1, a2)
    SUBSET_TYPE_ID_RANGE = 0,
    /// copy the entries whose id satisfies id % a1 == a2
    SUBSET_TYPE_ID_MOD = 1,
    // --- selection by position within the inverted lists ---
    /// from each list, copy the contiguous slice that starts at the list's
    /// proportional share of a1 and ends at its proportional share of a2,
    /// so that a2 - a1 entries are copied in total (about a1 entries are
    /// skipped before them). Intended for 0 <= a1 <= a2 <= ntotal.
    SUBSET_TYPE_ELEMENT_RANGE = 2,
    /// from each list of size n, copy the slice
    /// [n * a2 / a1, n * (a2 + 1) / a1), i.e. fraction a2 out of a1,
    /// with 0 <= a2 < a1
    SUBSET_TYPE_INVLIST_FRACTION = 3,
    // --- selection of whole lists ---
    /// copy the whole inverted lists a1 .. a2 - 1
    SUBSET_TYPE_INVLIST = 4
};
143
+
144
+ /** copy a subset of the entries index to the other index
145
+ * @return number of entries copied
146
+ */
147
+ size_t copy_subset_to(
148
+ InvertedLists& other,
149
+ subset_type_t subset_type,
150
+ idx_t a1,
151
+ idx_t a2) const;
112
152
 
113
153
  /*************************
114
154
  * statistics */
@@ -154,7 +154,7 @@ struct OnDiskInvertedLists::OngoingPrefetch {
154
154
  const OnDiskInvertedLists* od = pf->od;
155
155
  od->locks->lock_1(list_no);
156
156
  size_t n = od->list_size(list_no);
157
- const Index::idx_t* idx = od->get_ids(list_no);
157
+ const idx_t* idx = od->get_ids(list_no);
158
158
  const uint8_t* codes = od->get_codes(list_no);
159
159
  int cs = 0;
160
160
  for (size_t i = 0; i < n; i++) {
@@ -389,7 +389,7 @@ const uint8_t* OnDiskInvertedLists::get_codes(size_t list_no) const {
389
389
  return ptr + lists[list_no].offset;
390
390
  }
391
391
 
392
- const Index::idx_t* OnDiskInvertedLists::get_ids(size_t list_no) const {
392
+ const idx_t* OnDiskInvertedLists::get_ids(size_t list_no) const {
393
393
  if (lists[list_no].offset == INVALID_OFFSET) {
394
394
  return nullptr;
395
395
  }
@@ -781,7 +781,7 @@ InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
781
781
  OnDiskInvertedLists::List& l = ails->lists[i];
782
782
  l.size = l.capacity = sizes[i];
783
783
  l.offset = o;
784
- o += l.size * (sizeof(OnDiskInvertedLists::idx_t) + ails->code_size);
784
+ o += l.size * (sizeof(idx_t) + ails->code_size);
785
785
  }
786
786
  // resume normal reading of file
787
787
  fseek(fdesc, o, SEEK_SET);
@@ -31,7 +31,7 @@ struct OnDiskOneList {
31
31
 
32
32
  /** On-disk storage of inverted lists.
33
33
  *
34
- * The data is stored in a mmapped chunk of memory (base ptointer ptr,
34
+ * The data is stored in a mmapped chunk of memory (base pointer ptr,
35
35
  * size totsize). Each list is a range of memory that contains (object
36
36
  * List) that contains:
37
37
  *
@@ -118,7 +118,7 @@ PyCallbackIDSelector::PyCallbackIDSelector(PyObject* callback)
118
118
  Py_INCREF(callback);
119
119
  }
120
120
 
121
- bool PyCallbackIDSelector::is_member(idx_t id) const {
121
+ bool PyCallbackIDSelector::is_member(faiss::idx_t id) const {
122
122
  FAISS_THROW_IF_NOT((id >> 32) == 0);
123
123
  PyThreadLock gil;
124
124
  PyObject* result = PyObject_CallFunction(callback, "(n)", int(id));
@@ -54,7 +54,7 @@ struct PyCallbackIDSelector : faiss::IDSelector {
54
54
 
55
55
  explicit PyCallbackIDSelector(PyObject* callback);
56
56
 
57
- bool is_member(idx_t id) const override;
57
+ bool is_member(faiss::idx_t id) const override;
58
58
 
59
59
  ~PyCallbackIDSelector() override;
60
60
  };
@@ -98,7 +98,9 @@ struct AlignedTableTightAlloc {
98
98
  AlignedTableTightAlloc<T, A>& operator=(
99
99
  const AlignedTableTightAlloc<T, A>& other) {
100
100
  resize(other.numel);
101
- memcpy(ptr, other.ptr, sizeof(T) * numel);
101
+ if (numel > 0) {
102
+ memcpy(ptr, other.ptr, sizeof(T) * numel);
103
+ }
102
104
  return *this;
103
105
  }
104
106
 
@@ -9,6 +9,7 @@
9
9
 
10
10
  /* Function for soft heap */
11
11
 
12
+ #include <faiss/impl/FaissAssert.h>
12
13
  #include <faiss/utils/Heap.h>
13
14
 
14
15
  namespace faiss {
@@ -32,7 +33,7 @@ void HeapArray<C>::addn(size_t nj, const T* vin, TI j0, size_t i0, int64_t ni) {
32
33
  if (ni == -1)
33
34
  ni = nh;
34
35
  assert(i0 >= 0 && i0 + ni <= nh);
35
- #pragma omp parallel for
36
+ #pragma omp parallel for if (ni * nj > 100000)
36
37
  for (int64_t i = i0; i < i0 + ni; i++) {
37
38
  T* __restrict simi = get_val(i);
38
39
  TI* __restrict idxi = get_ids(i);
@@ -62,7 +63,7 @@ void HeapArray<C>::addn_with_ids(
62
63
  if (ni == -1)
63
64
  ni = nh;
64
65
  assert(i0 >= 0 && i0 + ni <= nh);
65
- #pragma omp parallel for
66
+ #pragma omp parallel for if (ni * nj > 100000)
66
67
  for (int64_t i = i0; i < i0 + ni; i++) {
67
68
  T* __restrict simi = get_val(i);
68
69
  TI* __restrict idxi = get_ids(i);
@@ -78,9 +79,38 @@ void HeapArray<C>::addn_with_ids(
78
79
  }
79
80
  }
80
81
 
82
+ template <typename C>
83
+ void HeapArray<C>::addn_query_subset_with_ids(
84
+ size_t nsubset,
85
+ const TI* subset,
86
+ size_t nj,
87
+ const T* vin,
88
+ const TI* id_in,
89
+ int64_t id_stride) {
90
+ FAISS_THROW_IF_NOT_MSG(id_in, "anonymous ids not supported");
91
+ if (id_stride < 0) {
92
+ id_stride = nj;
93
+ }
94
+ #pragma omp parallel for if (nsubset * nj > 100000)
95
+ for (int64_t si = 0; si < nsubset; si++) {
96
+ T i = subset[si];
97
+ T* __restrict simi = get_val(i);
98
+ TI* __restrict idxi = get_ids(i);
99
+ const T* ip_line = vin + si * nj;
100
+ const TI* id_line = id_in + si * id_stride;
101
+
102
+ for (size_t j = 0; j < nj; j++) {
103
+ T ip = ip_line[j];
104
+ if (C::cmp(simi[0], ip)) {
105
+ heap_replace_top<C>(k, simi, idxi, ip, id_line[j]);
106
+ }
107
+ }
108
+ }
109
+ }
110
+
81
111
  template <typename C>
82
112
  void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
83
- #pragma omp parallel for
113
+ #pragma omp parallel for if (nh * k > 100000)
84
114
  for (int64_t j = 0; j < nh; j++) {
85
115
  int64_t imin = -1;
86
116
  typename C::T xval = C::Crev::neutral();
@@ -109,4 +139,110 @@ template struct HeapArray<CMax<float, int64_t>>;
109
139
  template struct HeapArray<CMin<int, int64_t>>;
110
140
  template struct HeapArray<CMax<int, int64_t>>;
111
141
 
142
+ /**********************************************************
143
+ * merge knn search results
144
+ **********************************************************/
145
+
146
+ /** Merge result tables from several shards. The per-shard results are assumed
147
+ * to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
148
+ * element heap because we want the best (ie. lowest for L2) result to be on
149
+ * top, not the worst.
150
+ *
151
+ * @param all_distances size (nshard, n, k)
152
+ * @param all_labels size (nshard, n, k)
153
+ * @param distances output distances, size (n, k)
154
+ * @param labels output labels, size (n, k)
155
+ */
156
+ template <class idx_t, class C>
157
+ void merge_knn_results(
158
+ size_t n,
159
+ size_t k,
160
+ typename C::TI nshard,
161
+ const typename C::T* all_distances,
162
+ const idx_t* all_labels,
163
+ typename C::T* distances,
164
+ idx_t* labels) {
165
+ using distance_t = typename C::T;
166
+ if (k == 0) {
167
+ return;
168
+ }
169
+ long stride = n * k;
170
+ #pragma omp parallel if (n * nshard * k > 100000)
171
+ {
172
+ std::vector<int> buf(2 * nshard);
173
+ // index in each shard's result list
174
+ int* pointer = buf.data();
175
+ // (shard_ids, heap_vals): heap that indexes
176
+ // shard -> current distance for this shard
177
+ int* shard_ids = pointer + nshard;
178
+ std::vector<distance_t> buf2(nshard);
179
+ distance_t* heap_vals = buf2.data();
180
+ #pragma omp for
181
+ for (long i = 0; i < n; i++) {
182
+ // the heap maps values to the shard where they are
183
+ // produced.
184
+ const distance_t* D_in = all_distances + i * k;
185
+ const idx_t* I_in = all_labels + i * k;
186
+ int heap_size = 0;
187
+
188
+ // push the first element of each shard (if not -1)
189
+ for (long s = 0; s < nshard; s++) {
190
+ pointer[s] = 0;
191
+ if (I_in[stride * s] >= 0) {
192
+ heap_push<C>(
193
+ ++heap_size,
194
+ heap_vals,
195
+ shard_ids,
196
+ D_in[stride * s],
197
+ s);
198
+ }
199
+ }
200
+
201
+ distance_t* D = distances + i * k;
202
+ idx_t* I = labels + i * k;
203
+
204
+ int j;
205
+ for (j = 0; j < k && heap_size > 0; j++) {
206
+ // pop element from best shard
207
+ int s = shard_ids[0]; // top of heap
208
+ int& p = pointer[s];
209
+ D[j] = heap_vals[0];
210
+ I[j] = I_in[stride * s + p];
211
+
212
+ // pop from shard, advance pointer for this shard
213
+ heap_pop<C>(heap_size--, heap_vals, shard_ids);
214
+ p++;
215
+ if (p < k && I_in[stride * s + p] >= 0) {
216
+ heap_push<C>(
217
+ ++heap_size,
218
+ heap_vals,
219
+ shard_ids,
220
+ D_in[stride * s + p],
221
+ s);
222
+ }
223
+ }
224
+ for (; j < k; j++) {
225
+ I[j] = -1;
226
+ D[j] = C::Crev::neutral();
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ // explicit instanciations
233
+ #define INSTANTIATE(C, distance_t) \
234
+ template void merge_knn_results<int64_t, C<distance_t, int>>( \
235
+ size_t, \
236
+ size_t, \
237
+ int, \
238
+ const distance_t*, \
239
+ const int64_t*, \
240
+ distance_t*, \
241
+ int64_t*);
242
+
243
+ INSTANTIATE(CMin, float);
244
+ INSTANTIATE(CMax, float);
245
+ INSTANTIATE(CMin, int32_t);
246
+ INSTANTIATE(CMax, int32_t);
247
+
112
248
  } // namespace faiss
@@ -413,6 +413,19 @@ struct HeapArray {
413
413
  size_t i0 = 0,
414
414
  int64_t ni = -1);
415
415
 
416
+ /** same as addn_with_ids, but for just a subset of queries
417
+ *
418
+ * @param nsubset number of query entries to update
419
+ * @param subset indexes of queries to update, in 0..nh-1, size nsubset
420
+ */
421
+ void addn_query_subset_with_ids(
422
+ size_t nsubset,
423
+ const TI* subset,
424
+ size_t nj,
425
+ const T* vin,
426
+ const TI* id_in = nullptr,
427
+ int64_t id_stride = 0);
428
+
416
429
  /// reorder all the heaps
417
430
  void reorder();
418
431
 
@@ -431,7 +444,7 @@ typedef HeapArray<CMin<int, int64_t>> int_minheap_array_t;
431
444
  typedef HeapArray<CMax<float, int64_t>> float_maxheap_array_t;
432
445
  typedef HeapArray<CMax<int, int64_t>> int_maxheap_array_t;
433
446
 
434
- // The heap templates are instanciated explicitly in Heap.cpp
447
+ // The heap templates are instantiated explicitly in Heap.cpp
435
448
 
436
449
  /*********************************************************************
437
450
  * Indirect heaps: instead of having
@@ -492,6 +505,27 @@ inline void indirect_heap_push(
492
505
  bh_ids[i] = id;
493
506
  }
494
507
 
508
+ /** Merge result tables from several shards. The per-shard results are assumed
509
+ * to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
510
+ * element heap because we want the best (ie. lowest for L2) result to be on
511
+ * top, not the worst. Also, it needs to hold an index of a shard id (ie.
512
+ * usually int32 is more than enough).
513
+ *
514
+ * @param all_distances size (nshard, n, k)
515
+ * @param all_labels size (nshard, n, k)
516
+ * @param distances output distances, size (n, k)
517
+ * @param labels output labels, size (n, k)
518
+ */
519
+ template <class idx_t, class C>
520
+ void merge_knn_results(
521
+ size_t n,
522
+ size_t k,
523
+ typename C::TI nshard,
524
+ const typename C::T* all_distances,
525
+ const idx_t* all_labels,
526
+ typename C::T* distances,
527
+ idx_t* labels);
528
+
495
529
  } // namespace faiss
496
530
 
497
531
  #endif /* FAISS_Heap_h */