faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -15,6 +15,8 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
+ struct IDSelector;
19
+
18
20
  // When offsets list id + offset are encoded in an uint64
19
21
  // we call this LO = list-offset
20
22
 
@@ -34,8 +36,6 @@ inline uint64_t lo_offset(uint64_t lo) {
34
36
  * Direct map: a way to map back from ids to inverted lists
35
37
  */
36
38
  struct DirectMap {
37
- typedef Index::idx_t idx_t;
38
-
39
39
  enum Type {
40
40
  NoMap = 0, // default
41
41
  Array = 1, // sequential ids (only for add, no add_with_ids)
@@ -91,8 +91,6 @@ struct DirectMap {
91
91
 
92
92
  /// Thread-safe way of updating the direct_map
93
93
  struct DirectMapAdd {
94
- typedef Index::idx_t idx_t;
95
-
96
94
  using Type = DirectMap::Type;
97
95
 
98
96
  DirectMap& direct_map;
@@ -10,23 +10,32 @@
10
10
  #include <faiss/invlists/InvertedLists.h>
11
11
 
12
12
  #include <cstdio>
13
+ #include <memory>
13
14
 
14
15
  #include <faiss/impl/FaissAssert.h>
15
16
  #include <faiss/utils/utils.h>
16
17
 
17
18
  namespace faiss {
18
19
 
20
+ InvertedListsIterator::~InvertedListsIterator() {}
21
+
19
22
  /*****************************************
20
23
  * InvertedLists implementation
21
24
  ******************************************/
22
25
 
23
26
  InvertedLists::InvertedLists(size_t nlist, size_t code_size)
24
- : nlist(nlist), code_size(code_size) {}
27
+ : nlist(nlist), code_size(code_size), use_iterator(false) {}
25
28
 
26
29
  InvertedLists::~InvertedLists() {}
27
30
 
28
- InvertedLists::idx_t InvertedLists::get_single_id(size_t list_no, size_t offset)
29
- const {
31
+ bool InvertedLists::is_empty(size_t list_no) const {
32
+ return use_iterator
33
+ ? !std::unique_ptr<InvertedListsIterator>(get_iterator(list_no))
34
+ ->is_available()
35
+ : list_size(list_no) == 0;
36
+ }
37
+
38
+ idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
30
39
  assert(offset < list_size(list_no));
31
40
  const idx_t* ids = get_ids(list_no);
32
41
  idx_t id = ids[offset];
@@ -67,6 +76,10 @@ void InvertedLists::reset() {
67
76
  }
68
77
  }
69
78
 
79
+ InvertedListsIterator* InvertedLists::get_iterator(size_t /*list_no*/) const {
80
+ FAISS_THROW_MSG("get_iterator is not supported");
81
+ }
82
+
70
83
  void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
71
84
  #pragma omp parallel for
72
85
  for (idx_t i = 0; i < nlist; i++) {
@@ -87,6 +100,98 @@ void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
87
100
  }
88
101
  }
89
102
 
103
+ size_t InvertedLists::copy_subset_to(
104
+ InvertedLists& oivf,
105
+ subset_type_t subset_type,
106
+ idx_t a1,
107
+ idx_t a2) const {
108
+ FAISS_THROW_IF_NOT(nlist == oivf.nlist);
109
+ FAISS_THROW_IF_NOT(code_size == oivf.code_size);
110
+ FAISS_THROW_IF_NOT_FMT(
111
+ subset_type >= 0 && subset_type <= 4,
112
+ "subset type %d not implemented",
113
+ subset_type);
114
+ size_t accu_n = 0;
115
+ size_t accu_a1 = 0;
116
+ size_t accu_a2 = 0;
117
+ size_t n_added = 0;
118
+
119
+ size_t ntotal = 0;
120
+ if (subset_type == 2) {
121
+ ntotal = compute_ntotal();
122
+ }
123
+
124
+ for (idx_t list_no = 0; list_no < nlist; list_no++) {
125
+ size_t n = list_size(list_no);
126
+ ScopedIds ids_in(this, list_no);
127
+
128
+ if (subset_type == SUBSET_TYPE_ID_RANGE) {
129
+ for (idx_t i = 0; i < n; i++) {
130
+ idx_t id = ids_in[i];
131
+ if (a1 <= id && id < a2) {
132
+ oivf.add_entry(
133
+ list_no,
134
+ get_single_id(list_no, i),
135
+ ScopedCodes(this, list_no, i).get());
136
+ n_added++;
137
+ }
138
+ }
139
+ } else if (subset_type == SUBSET_TYPE_ID_MOD) {
140
+ for (idx_t i = 0; i < n; i++) {
141
+ idx_t id = ids_in[i];
142
+ if (id % a1 == a2) {
143
+ oivf.add_entry(
144
+ list_no,
145
+ get_single_id(list_no, i),
146
+ ScopedCodes(this, list_no, i).get());
147
+ n_added++;
148
+ }
149
+ }
150
+ } else if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) {
151
+ // see what is allocated to a1 and to a2
152
+ size_t next_accu_n = accu_n + n;
153
+ size_t next_accu_a1 = next_accu_n * a1 / ntotal;
154
+ size_t i1 = next_accu_a1 - accu_a1;
155
+ size_t next_accu_a2 = next_accu_n * a2 / ntotal;
156
+ size_t i2 = next_accu_a2 - accu_a2;
157
+
158
+ for (idx_t i = i1; i < i2; i++) {
159
+ oivf.add_entry(
160
+ list_no,
161
+ get_single_id(list_no, i),
162
+ ScopedCodes(this, list_no, i).get());
163
+ }
164
+
165
+ n_added += i2 - i1;
166
+ accu_a1 = next_accu_a1;
167
+ accu_a2 = next_accu_a2;
168
+ } else if (subset_type == SUBSET_TYPE_INVLIST_FRACTION) {
169
+ size_t i1 = n * a2 / a1;
170
+ size_t i2 = n * (a2 + 1) / a1;
171
+
172
+ for (idx_t i = i1; i < i2; i++) {
173
+ oivf.add_entry(
174
+ list_no,
175
+ get_single_id(list_no, i),
176
+ ScopedCodes(this, list_no, i).get());
177
+ }
178
+
179
+ n_added += i2 - i1;
180
+ } else if (subset_type == SUBSET_TYPE_INVLIST) {
181
+ if (list_no >= a1 && list_no < a2) {
182
+ oivf.add_entries(
183
+ list_no,
184
+ n,
185
+ ScopedIds(this, list_no).get(),
186
+ ScopedCodes(this, list_no).get());
187
+ n_added += n;
188
+ }
189
+ }
190
+ accu_n += n;
191
+ }
192
+ return n_added;
193
+ }
194
+
90
195
  double InvertedLists::imbalance_factor() const {
91
196
  std::vector<int> hist(nlist);
92
197
 
@@ -109,7 +214,9 @@ void InvertedLists::print_stats() const {
109
214
  }
110
215
  for (size_t i = 0; i < sizes.size(); i++) {
111
216
  if (sizes[i]) {
112
- printf("list size in < %d: %d instances\n", 1 << i, sizes[i]);
217
+ printf("list size in < %zu: %d instances\n",
218
+ static_cast<size_t>(1) << i,
219
+ sizes[i]);
113
220
  }
114
221
  }
115
222
  }
@@ -158,7 +265,7 @@ const uint8_t* ArrayInvertedLists::get_codes(size_t list_no) const {
158
265
  return codes[list_no].data();
159
266
  }
160
267
 
161
- const InvertedLists::idx_t* ArrayInvertedLists::get_ids(size_t list_no) const {
268
+ const idx_t* ArrayInvertedLists::get_ids(size_t list_no) const {
162
269
  assert(list_no < nlist);
163
270
  return ids[list_no].data();
164
271
  }
@@ -267,7 +374,7 @@ void HStackInvertedLists::release_codes(size_t, const uint8_t* codes) const {
267
374
  delete[] codes;
268
375
  }
269
376
 
270
- const Index::idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
377
+ const idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
271
378
  idx_t *ids = new idx_t[list_size(list_no)], *c = ids;
272
379
 
273
380
  for (int i = 0; i < ils.size(); i++) {
@@ -281,8 +388,7 @@ const Index::idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
281
388
  return ids;
282
389
  }
283
390
 
284
- Index::idx_t HStackInvertedLists::get_single_id(size_t list_no, size_t offset)
285
- const {
391
+ idx_t HStackInvertedLists::get_single_id(size_t list_no, size_t offset) const {
286
392
  for (int i = 0; i < ils.size(); i++) {
287
393
  const InvertedLists* il = ils[i];
288
394
  size_t sz = il->list_size(list_no);
@@ -312,8 +418,6 @@ void HStackInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist)
312
418
 
313
419
  namespace {
314
420
 
315
- using idx_t = InvertedLists::idx_t;
316
-
317
421
  idx_t translate_list_no(const SliceInvertedLists* sil, idx_t list_no) {
318
422
  FAISS_THROW_IF_NOT(list_no >= 0 && list_no < sil->nlist);
319
423
  return list_no + sil->i0;
@@ -349,12 +453,11 @@ void SliceInvertedLists::release_codes(size_t list_no, const uint8_t* codes)
349
453
  return il->release_codes(translate_list_no(this, list_no), codes);
350
454
  }
351
455
 
352
- const Index::idx_t* SliceInvertedLists::get_ids(size_t list_no) const {
456
+ const idx_t* SliceInvertedLists::get_ids(size_t list_no) const {
353
457
  return il->get_ids(translate_list_no(this, list_no));
354
458
  }
355
459
 
356
- Index::idx_t SliceInvertedLists::get_single_id(size_t list_no, size_t offset)
357
- const {
460
+ idx_t SliceInvertedLists::get_single_id(size_t list_no, size_t offset) const {
358
461
  return il->get_single_id(translate_list_no(this, list_no), offset);
359
462
  }
360
463
 
@@ -380,8 +483,6 @@ void SliceInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist)
380
483
 
381
484
  namespace {
382
485
 
383
- using idx_t = InvertedLists::idx_t;
384
-
385
486
  // find the invlist this number belongs to
386
487
  int translate_list_no(const VStackInvertedLists* vil, idx_t list_no) {
387
488
  FAISS_THROW_IF_NOT(list_no >= 0 && list_no < vil->nlist);
@@ -449,14 +550,13 @@ void VStackInvertedLists::release_codes(size_t list_no, const uint8_t* codes)
449
550
  return ils[i]->release_codes(list_no, codes);
450
551
  }
451
552
 
452
- const Index::idx_t* VStackInvertedLists::get_ids(size_t list_no) const {
553
+ const idx_t* VStackInvertedLists::get_ids(size_t list_no) const {
453
554
  int i = translate_list_no(this, list_no);
454
555
  list_no -= cumsz[i];
455
556
  return ils[i]->get_ids(list_no);
456
557
  }
457
558
 
458
- Index::idx_t VStackInvertedLists::get_single_id(size_t list_no, size_t offset)
459
- const {
559
+ idx_t VStackInvertedLists::get_single_id(size_t list_no, size_t offset) const {
460
560
  int i = translate_list_no(this, list_no);
461
561
  list_no -= cumsz[i];
462
562
  return ils[i]->get_single_id(list_no, offset);
@@ -15,11 +15,18 @@
15
15
  * the interface.
16
16
  */
17
17
 
18
- #include <faiss/Index.h>
18
+ #include <faiss/MetricType.h>
19
19
  #include <vector>
20
20
 
21
21
  namespace faiss {
22
22
 
23
+ struct InvertedListsIterator {
24
+ virtual ~InvertedListsIterator();
25
+ virtual bool is_available() const = 0;
26
+ virtual void next() = 0;
27
+ virtual std::pair<idx_t, const uint8_t*> get_id_and_codes() = 0;
28
+ };
29
+
23
30
  /** Table of inverted lists
24
31
  * multithreading rules:
25
32
  * - concurrent read accesses are allowed
@@ -28,13 +35,14 @@ namespace faiss {
28
35
  * are allowed
29
36
  */
30
37
  struct InvertedLists {
31
- typedef Index::idx_t idx_t;
32
-
33
38
  size_t nlist; ///< number of possible key values
34
39
  size_t code_size; ///< code size per vector in bytes
40
+ bool use_iterator;
35
41
 
36
42
  InvertedLists(size_t nlist, size_t code_size);
37
43
 
44
+ virtual ~InvertedLists();
45
+
38
46
  /// used for BlockInvertedLists, where the codes are packed into groups
39
47
  /// and the individual code size is meaningless
40
48
  static const size_t INVALID_CODE_SIZE = static_cast<size_t>(-1);
@@ -42,9 +50,15 @@ struct InvertedLists {
42
50
  /*************************
43
51
  * Read only functions */
44
52
 
53
+ // check if the list is empty
54
+ bool is_empty(size_t list_no) const;
55
+
45
56
  /// get the size of a list
46
57
  virtual size_t list_size(size_t list_no) const = 0;
47
58
 
59
+ /// get iterable for lists that use_iterator
60
+ virtual InvertedListsIterator* get_iterator(size_t list_no) const;
61
+
48
62
  /** get the codes for an inverted list
49
63
  * must be released by release_codes
50
64
  *
@@ -105,10 +119,36 @@ struct InvertedLists {
105
119
 
106
120
  virtual void reset();
107
121
 
122
+ /*************************
123
+ * high level functions */
124
+
108
125
  /// move all entries from oivf (empty on output)
109
126
  void merge_from(InvertedLists* oivf, size_t add_id);
110
127
 
111
- virtual ~InvertedLists();
128
+ // how to copy a subset of elements from the inverted lists
129
+ // This depends on two integers, a1 and a2.
130
+ enum subset_type_t : int {
131
+ // depends on IDs
132
+ SUBSET_TYPE_ID_RANGE = 0, // copies ids in [a1, a2)
133
+ SUBSET_TYPE_ID_MOD = 1, // copies ids if id % a1 == a2
134
+ // depends on order within invlists
135
+ SUBSET_TYPE_ELEMENT_RANGE =
136
+ 2, // copies fractions of invlists so that a1 elements are left
137
+ // before and a2 after
138
+ SUBSET_TYPE_INVLIST_FRACTION =
139
+ 3, // take fraction a2 out of a1 from each invlist, 0 <= a2 < a1
140
+ // copy only inverted lists a1:a2
141
+ SUBSET_TYPE_INVLIST = 4
142
+ };
143
+
144
+ /** copy a subset of the entries index to the other index
145
+ * @return number of entries copied
146
+ */
147
+ size_t copy_subset_to(
148
+ InvertedLists& other,
149
+ subset_type_t subset_type,
150
+ idx_t a1,
151
+ idx_t a2) const;
112
152
 
113
153
  /*************************
114
154
  * statistics */
@@ -154,7 +154,7 @@ struct OnDiskInvertedLists::OngoingPrefetch {
154
154
  const OnDiskInvertedLists* od = pf->od;
155
155
  od->locks->lock_1(list_no);
156
156
  size_t n = od->list_size(list_no);
157
- const Index::idx_t* idx = od->get_ids(list_no);
157
+ const idx_t* idx = od->get_ids(list_no);
158
158
  const uint8_t* codes = od->get_codes(list_no);
159
159
  int cs = 0;
160
160
  for (size_t i = 0; i < n; i++) {
@@ -389,7 +389,7 @@ const uint8_t* OnDiskInvertedLists::get_codes(size_t list_no) const {
389
389
  return ptr + lists[list_no].offset;
390
390
  }
391
391
 
392
- const Index::idx_t* OnDiskInvertedLists::get_ids(size_t list_no) const {
392
+ const idx_t* OnDiskInvertedLists::get_ids(size_t list_no) const {
393
393
  if (lists[list_no].offset == INVALID_OFFSET) {
394
394
  return nullptr;
395
395
  }
@@ -781,7 +781,7 @@ InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
781
781
  OnDiskInvertedLists::List& l = ails->lists[i];
782
782
  l.size = l.capacity = sizes[i];
783
783
  l.offset = o;
784
- o += l.size * (sizeof(OnDiskInvertedLists::idx_t) + ails->code_size);
784
+ o += l.size * (sizeof(idx_t) + ails->code_size);
785
785
  }
786
786
  // resume normal reading of file
787
787
  fseek(fdesc, o, SEEK_SET);
@@ -31,7 +31,7 @@ struct OnDiskOneList {
31
31
 
32
32
  /** On-disk storage of inverted lists.
33
33
  *
34
- * The data is stored in a mmapped chunk of memory (base ptointer ptr,
34
+ * The data is stored in a mmapped chunk of memory (base pointer ptr,
35
35
  * size totsize). Each list is a range of memory that contains (object
36
36
  * List) that contains:
37
37
  *
@@ -118,7 +118,7 @@ PyCallbackIDSelector::PyCallbackIDSelector(PyObject* callback)
118
118
  Py_INCREF(callback);
119
119
  }
120
120
 
121
- bool PyCallbackIDSelector::is_member(idx_t id) const {
121
+ bool PyCallbackIDSelector::is_member(faiss::idx_t id) const {
122
122
  FAISS_THROW_IF_NOT((id >> 32) == 0);
123
123
  PyThreadLock gil;
124
124
  PyObject* result = PyObject_CallFunction(callback, "(n)", int(id));
@@ -54,7 +54,7 @@ struct PyCallbackIDSelector : faiss::IDSelector {
54
54
 
55
55
  explicit PyCallbackIDSelector(PyObject* callback);
56
56
 
57
- bool is_member(idx_t id) const override;
57
+ bool is_member(faiss::idx_t id) const override;
58
58
 
59
59
  ~PyCallbackIDSelector() override;
60
60
  };
@@ -98,7 +98,9 @@ struct AlignedTableTightAlloc {
98
98
  AlignedTableTightAlloc<T, A>& operator=(
99
99
  const AlignedTableTightAlloc<T, A>& other) {
100
100
  resize(other.numel);
101
- memcpy(ptr, other.ptr, sizeof(T) * numel);
101
+ if (numel > 0) {
102
+ memcpy(ptr, other.ptr, sizeof(T) * numel);
103
+ }
102
104
  return *this;
103
105
  }
104
106
 
@@ -9,6 +9,7 @@
9
9
 
10
10
  /* Function for soft heap */
11
11
 
12
+ #include <faiss/impl/FaissAssert.h>
12
13
  #include <faiss/utils/Heap.h>
13
14
 
14
15
  namespace faiss {
@@ -32,7 +33,7 @@ void HeapArray<C>::addn(size_t nj, const T* vin, TI j0, size_t i0, int64_t ni) {
32
33
  if (ni == -1)
33
34
  ni = nh;
34
35
  assert(i0 >= 0 && i0 + ni <= nh);
35
- #pragma omp parallel for
36
+ #pragma omp parallel for if (ni * nj > 100000)
36
37
  for (int64_t i = i0; i < i0 + ni; i++) {
37
38
  T* __restrict simi = get_val(i);
38
39
  TI* __restrict idxi = get_ids(i);
@@ -62,7 +63,7 @@ void HeapArray<C>::addn_with_ids(
62
63
  if (ni == -1)
63
64
  ni = nh;
64
65
  assert(i0 >= 0 && i0 + ni <= nh);
65
- #pragma omp parallel for
66
+ #pragma omp parallel for if (ni * nj > 100000)
66
67
  for (int64_t i = i0; i < i0 + ni; i++) {
67
68
  T* __restrict simi = get_val(i);
68
69
  TI* __restrict idxi = get_ids(i);
@@ -78,9 +79,38 @@ void HeapArray<C>::addn_with_ids(
78
79
  }
79
80
  }
80
81
 
82
+ template <typename C>
83
+ void HeapArray<C>::addn_query_subset_with_ids(
84
+ size_t nsubset,
85
+ const TI* subset,
86
+ size_t nj,
87
+ const T* vin,
88
+ const TI* id_in,
89
+ int64_t id_stride) {
90
+ FAISS_THROW_IF_NOT_MSG(id_in, "anonymous ids not supported");
91
+ if (id_stride < 0) {
92
+ id_stride = nj;
93
+ }
94
+ #pragma omp parallel for if (nsubset * nj > 100000)
95
+ for (int64_t si = 0; si < nsubset; si++) {
96
+ T i = subset[si];
97
+ T* __restrict simi = get_val(i);
98
+ TI* __restrict idxi = get_ids(i);
99
+ const T* ip_line = vin + si * nj;
100
+ const TI* id_line = id_in + si * id_stride;
101
+
102
+ for (size_t j = 0; j < nj; j++) {
103
+ T ip = ip_line[j];
104
+ if (C::cmp(simi[0], ip)) {
105
+ heap_replace_top<C>(k, simi, idxi, ip, id_line[j]);
106
+ }
107
+ }
108
+ }
109
+ }
110
+
81
111
  template <typename C>
82
112
  void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
83
- #pragma omp parallel for
113
+ #pragma omp parallel for if (nh * k > 100000)
84
114
  for (int64_t j = 0; j < nh; j++) {
85
115
  int64_t imin = -1;
86
116
  typename C::T xval = C::Crev::neutral();
@@ -109,4 +139,110 @@ template struct HeapArray<CMax<float, int64_t>>;
109
139
  template struct HeapArray<CMin<int, int64_t>>;
110
140
  template struct HeapArray<CMax<int, int64_t>>;
111
141
 
142
+ /**********************************************************
143
+ * merge knn search results
144
+ **********************************************************/
145
+
146
+ /** Merge result tables from several shards. The per-shard results are assumed
147
+ * to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
148
+ * element heap because we want the best (ie. lowest for L2) result to be on
149
+ * top, not the worst.
150
+ *
151
+ * @param all_distances size (nshard, n, k)
152
+ * @param all_labels size (nshard, n, k)
153
+ * @param distances output distances, size (n, k)
154
+ * @param labels output labels, size (n, k)
155
+ */
156
+ template <class idx_t, class C>
157
+ void merge_knn_results(
158
+ size_t n,
159
+ size_t k,
160
+ typename C::TI nshard,
161
+ const typename C::T* all_distances,
162
+ const idx_t* all_labels,
163
+ typename C::T* distances,
164
+ idx_t* labels) {
165
+ using distance_t = typename C::T;
166
+ if (k == 0) {
167
+ return;
168
+ }
169
+ long stride = n * k;
170
+ #pragma omp parallel if (n * nshard * k > 100000)
171
+ {
172
+ std::vector<int> buf(2 * nshard);
173
+ // index in each shard's result list
174
+ int* pointer = buf.data();
175
+ // (shard_ids, heap_vals): heap that indexes
176
+ // shard -> current distance for this shard
177
+ int* shard_ids = pointer + nshard;
178
+ std::vector<distance_t> buf2(nshard);
179
+ distance_t* heap_vals = buf2.data();
180
+ #pragma omp for
181
+ for (long i = 0; i < n; i++) {
182
+ // the heap maps values to the shard where they are
183
+ // produced.
184
+ const distance_t* D_in = all_distances + i * k;
185
+ const idx_t* I_in = all_labels + i * k;
186
+ int heap_size = 0;
187
+
188
+ // push the first element of each shard (if not -1)
189
+ for (long s = 0; s < nshard; s++) {
190
+ pointer[s] = 0;
191
+ if (I_in[stride * s] >= 0) {
192
+ heap_push<C>(
193
+ ++heap_size,
194
+ heap_vals,
195
+ shard_ids,
196
+ D_in[stride * s],
197
+ s);
198
+ }
199
+ }
200
+
201
+ distance_t* D = distances + i * k;
202
+ idx_t* I = labels + i * k;
203
+
204
+ int j;
205
+ for (j = 0; j < k && heap_size > 0; j++) {
206
+ // pop element from best shard
207
+ int s = shard_ids[0]; // top of heap
208
+ int& p = pointer[s];
209
+ D[j] = heap_vals[0];
210
+ I[j] = I_in[stride * s + p];
211
+
212
+ // pop from shard, advance pointer for this shard
213
+ heap_pop<C>(heap_size--, heap_vals, shard_ids);
214
+ p++;
215
+ if (p < k && I_in[stride * s + p] >= 0) {
216
+ heap_push<C>(
217
+ ++heap_size,
218
+ heap_vals,
219
+ shard_ids,
220
+ D_in[stride * s + p],
221
+ s);
222
+ }
223
+ }
224
+ for (; j < k; j++) {
225
+ I[j] = -1;
226
+ D[j] = C::Crev::neutral();
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ // explicit instanciations
233
+ #define INSTANTIATE(C, distance_t) \
234
+ template void merge_knn_results<int64_t, C<distance_t, int>>( \
235
+ size_t, \
236
+ size_t, \
237
+ int, \
238
+ const distance_t*, \
239
+ const int64_t*, \
240
+ distance_t*, \
241
+ int64_t*);
242
+
243
+ INSTANTIATE(CMin, float);
244
+ INSTANTIATE(CMax, float);
245
+ INSTANTIATE(CMin, int32_t);
246
+ INSTANTIATE(CMax, int32_t);
247
+
112
248
  } // namespace faiss
@@ -413,6 +413,19 @@ struct HeapArray {
413
413
  size_t i0 = 0,
414
414
  int64_t ni = -1);
415
415
 
416
+ /** same as addn_with_ids, but for just a subset of queries
417
+ *
418
+ * @param nsubset number of query entries to update
419
+ * @param subset indexes of queries to update, in 0..nh-1, size nsubset
420
+ */
421
+ void addn_query_subset_with_ids(
422
+ size_t nsubset,
423
+ const TI* subset,
424
+ size_t nj,
425
+ const T* vin,
426
+ const TI* id_in = nullptr,
427
+ int64_t id_stride = 0);
428
+
416
429
  /// reorder all the heaps
417
430
  void reorder();
418
431
 
@@ -431,7 +444,7 @@ typedef HeapArray<CMin<int, int64_t>> int_minheap_array_t;
431
444
  typedef HeapArray<CMax<float, int64_t>> float_maxheap_array_t;
432
445
  typedef HeapArray<CMax<int, int64_t>> int_maxheap_array_t;
433
446
 
434
- // The heap templates are instanciated explicitly in Heap.cpp
447
+ // The heap templates are instantiated explicitly in Heap.cpp
435
448
 
436
449
  /*********************************************************************
437
450
  * Indirect heaps: instead of having
@@ -492,6 +505,27 @@ inline void indirect_heap_push(
492
505
  bh_ids[i] = id;
493
506
  }
494
507
 
508
+ /** Merge result tables from several shards. The per-shard results are assumed
509
+ * to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
510
+ * element heap because we want the best (ie. lowest for L2) result to be on
511
+ * top, not the worst. Also, it needs to hold an index of a shard id (ie.
512
+ * usually int32 is more than enough).
513
+ *
514
+ * @param all_distances size (nshard, n, k)
515
+ * @param all_labels size (nshard, n, k)
516
+ * @param distances output distances, size (n, k)
517
+ * @param labels output labels, size (n, k)
518
+ */
519
+ template <class idx_t, class C>
520
+ void merge_knn_results(
521
+ size_t n,
522
+ size_t k,
523
+ typename C::TI nshard,
524
+ const typename C::T* all_distances,
525
+ const idx_t* all_labels,
526
+ typename C::T* distances,
527
+ idx_t* labels);
528
+
495
529
  } // namespace faiss
496
530
 
497
531
  #endif /* FAISS_Heap_h */