faiss 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (77) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +18 -18
  4. data/README.md +1 -1
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/Clustering.cpp +318 -53
  7. data/vendor/faiss/Clustering.h +39 -11
  8. data/vendor/faiss/DirectMap.cpp +267 -0
  9. data/vendor/faiss/DirectMap.h +120 -0
  10. data/vendor/faiss/IVFlib.cpp +24 -4
  11. data/vendor/faiss/IVFlib.h +4 -0
  12. data/vendor/faiss/Index.h +5 -24
  13. data/vendor/faiss/Index2Layer.cpp +0 -1
  14. data/vendor/faiss/IndexBinary.h +7 -3
  15. data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
  16. data/vendor/faiss/IndexBinaryFlat.h +3 -0
  17. data/vendor/faiss/IndexBinaryHash.cpp +492 -0
  18. data/vendor/faiss/IndexBinaryHash.h +116 -0
  19. data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
  20. data/vendor/faiss/IndexBinaryIVF.h +14 -4
  21. data/vendor/faiss/IndexFlat.h +2 -1
  22. data/vendor/faiss/IndexHNSW.cpp +68 -16
  23. data/vendor/faiss/IndexHNSW.h +3 -3
  24. data/vendor/faiss/IndexIVF.cpp +72 -76
  25. data/vendor/faiss/IndexIVF.h +24 -5
  26. data/vendor/faiss/IndexIVFFlat.cpp +19 -54
  27. data/vendor/faiss/IndexIVFFlat.h +1 -11
  28. data/vendor/faiss/IndexIVFPQ.cpp +49 -26
  29. data/vendor/faiss/IndexIVFPQ.h +9 -10
  30. data/vendor/faiss/IndexIVFPQR.cpp +2 -2
  31. data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
  32. data/vendor/faiss/IndexLSH.h +4 -1
  33. data/vendor/faiss/IndexPreTransform.cpp +0 -1
  34. data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
  35. data/vendor/faiss/InvertedLists.cpp +0 -2
  36. data/vendor/faiss/MetaIndexes.cpp +0 -1
  37. data/vendor/faiss/MetricType.h +36 -0
  38. data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
  39. data/vendor/faiss/c_api/Clustering_c.h +11 -5
  40. data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
  41. data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
  42. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
  43. data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
  44. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
  45. data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
  46. data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
  47. data/vendor/faiss/gpu/GpuDistance.h +93 -0
  48. data/vendor/faiss/gpu/GpuIndex.h +7 -0
  49. data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
  50. data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
  51. data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
  52. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
  53. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
  54. data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
  55. data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
  56. data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
  57. data/vendor/faiss/impl/HNSW.cpp +0 -1
  58. data/vendor/faiss/impl/PolysemousTraining.h +5 -5
  59. data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
  60. data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
  61. data/vendor/faiss/impl/ProductQuantizer.h +42 -47
  62. data/vendor/faiss/impl/index_read.cpp +103 -7
  63. data/vendor/faiss/impl/index_write.cpp +101 -5
  64. data/vendor/faiss/impl/io.cpp +111 -1
  65. data/vendor/faiss/impl/io.h +38 -0
  66. data/vendor/faiss/index_factory.cpp +0 -1
  67. data/vendor/faiss/tests/test_merge.cpp +0 -1
  68. data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
  69. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
  70. data/vendor/faiss/utils/distances.cpp +4 -5
  71. data/vendor/faiss/utils/distances_simd.cpp +0 -1
  72. data/vendor/faiss/utils/hamming.cpp +85 -3
  73. data/vendor/faiss/utils/hamming.h +20 -0
  74. data/vendor/faiss/utils/utils.cpp +0 -96
  75. data/vendor/faiss/utils/utils.h +0 -15
  76. metadata +11 -3
  77. data/lib/faiss/ext.bundle +0 -0
@@ -46,8 +46,7 @@ struct IndexBinaryIVF : IndexBinary {
46
46
  bool use_heap = true;
47
47
 
48
48
  /// map for direct access to the elements. Enables reconstruct().
49
- bool maintain_direct_map;
50
- std::vector<idx_t> direct_map;
49
+ DirectMap direct_map;
51
50
 
52
51
  IndexBinary *quantizer; ///< quantizer that maps vectors to inverted lists
53
52
  size_t nlist; ///< number of possible key values
@@ -110,8 +109,11 @@ struct IndexBinaryIVF : IndexBinary {
110
109
  bool store_pairs=false) const;
111
110
 
112
111
  /** assign the vectors, then call search_preassign */
113
- virtual void search(idx_t n, const uint8_t *x, idx_t k,
114
- int32_t *distances, idx_t *labels) const override;
112
+ void search(idx_t n, const uint8_t *x, idx_t k,
113
+ int32_t *distances, idx_t *labels) const override;
114
+
115
+ void range_search(idx_t n, const uint8_t *x, int radius,
116
+ RangeSearchResult *result) const override;
115
117
 
116
118
  void reconstruct(idx_t key, uint8_t *recons) const override;
117
119
 
@@ -168,6 +170,8 @@ struct IndexBinaryIVF : IndexBinary {
168
170
  */
169
171
  void make_direct_map(bool new_maintain_direct_map=true);
170
172
 
173
+ void set_direct_map_type (DirectMap::Type type);
174
+
171
175
  void replace_invlists(InvertedLists *il, bool own=false);
172
176
  };
173
177
 
@@ -201,6 +205,12 @@ struct BinaryInvertedListScanner {
201
205
  int32_t *distances, idx_t *labels,
202
206
  size_t k) const = 0;
203
207
 
208
+ virtual void scan_codes_range (size_t n,
209
+ const uint8_t *codes,
210
+ const idx_t *ids,
211
+ int radius,
212
+ RangeQueryResult &result) const = 0;
213
+
204
214
  virtual ~BinaryInvertedListScanner () {}
205
215
 
206
216
  };
@@ -19,6 +19,7 @@ namespace faiss {
19
19
 
20
20
  /** Index that stores the full vectors and performs exhaustive search */
21
21
  struct IndexFlat: Index {
22
+
22
23
  /// database vectors, size ntotal * d
23
24
  std::vector<float> xb;
24
25
 
@@ -144,7 +145,7 @@ struct IndexRefineFlat: Index {
144
145
  };
145
146
 
146
147
 
147
- /// optimized version for 1D "vectors"
148
+ /// optimized version for 1D "vectors".
148
149
  struct IndexFlat1D:IndexFlatL2 {
149
150
  bool continuous_update; ///< is the permutation updated continuously?
150
151
 
@@ -26,7 +26,6 @@
26
26
  #include <stdint.h>
27
27
 
28
28
  #ifdef __SSE__
29
- #include <immintrin.h>
30
29
  #endif
31
30
 
32
31
  #include <faiss/utils/distances.h>
@@ -55,7 +54,6 @@ namespace faiss {
55
54
  using idx_t = Index::idx_t;
56
55
  using MinimaxHeap = HNSW::MinimaxHeap;
57
56
  using storage_idx_t = HNSW::storage_idx_t;
58
- using NodeDistCloser = HNSW::NodeDistCloser;
59
57
  using NodeDistFarther = HNSW::NodeDistFarther;
60
58
 
61
59
  HNSWStats hnsw_stats;
@@ -67,6 +65,50 @@ HNSWStats hnsw_stats;
67
65
  namespace {
68
66
 
69
67
 
68
+ /* Wrap the distance computer into one that negates the
69
+ distances. This makes supporting INNER_PRODUCT search easier */
70
+
71
+ struct NegativeDistanceComputer: DistanceComputer {
72
+
73
+ /// owned by this
74
+ DistanceComputer *basedis;
75
+
76
+ explicit NegativeDistanceComputer(DistanceComputer *basedis):
77
+ basedis(basedis)
78
+ {}
79
+
80
+ void set_query(const float *x) override {
81
+ basedis->set_query(x);
82
+ }
83
+
84
+ /// compute distance of vector i to current query
85
+ float operator () (idx_t i) override {
86
+ return -(*basedis)(i);
87
+ }
88
+
89
+ /// compute distance between two stored vectors
90
+ float symmetric_dis (idx_t i, idx_t j) override {
91
+ return -basedis->symmetric_dis(i, j);
92
+ }
93
+
94
+ virtual ~NegativeDistanceComputer ()
95
+ {
96
+ delete basedis;
97
+ }
98
+
99
+ };
100
+
101
+ DistanceComputer *storage_distance_computer(const Index *storage)
102
+ {
103
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
104
+ return new NegativeDistanceComputer(storage->get_distance_computer());
105
+ } else {
106
+ return storage->get_distance_computer();
107
+ }
108
+ }
109
+
110
+
111
+
70
112
  void hnsw_add_vertices(IndexHNSW &index_hnsw,
71
113
  size_t n0,
72
114
  size_t n, const float *x,
@@ -152,7 +194,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
152
194
  VisitedTable vt (ntotal);
153
195
 
154
196
  DistanceComputer *dis =
155
- index_hnsw.storage->get_distance_computer();
197
+ storage_distance_computer (index_hnsw.storage);
156
198
  ScopeDeleter1<DistanceComputer> del(dis);
157
199
  int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
158
200
  size_t counter = 0;
@@ -210,8 +252,8 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
210
252
  * IndexHNSW implementation
211
253
  **************************************************************/
212
254
 
213
- IndexHNSW::IndexHNSW(int d, int M):
214
- Index(d, METRIC_L2),
255
+ IndexHNSW::IndexHNSW(int d, int M, MetricType metric):
256
+ Index(d, metric),
215
257
  hnsw(M),
216
258
  own_fields(false),
217
259
  storage(nullptr),
@@ -258,7 +300,8 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
258
300
  #pragma omp parallel reduction(+ : nreorder)
259
301
  {
260
302
  VisitedTable vt (ntotal);
261
- DistanceComputer *dis = storage->get_distance_computer();
303
+
304
+ DistanceComputer *dis = storage_distance_computer(storage);
262
305
  ScopeDeleter1<DistanceComputer> del(dis);
263
306
 
264
307
  #pragma omp for
@@ -290,6 +333,14 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
290
333
  }
291
334
  InterruptCallback::check ();
292
335
  }
336
+
337
+ if (metric_type == METRIC_INNER_PRODUCT) {
338
+ // we need to revert the negated distances
339
+ for (size_t i = 0; i < k * n; i++) {
340
+ distances[i] = -distances[i];
341
+ }
342
+ }
343
+
293
344
  hnsw_stats.nreorder += nreorder;
294
345
  }
295
346
 
@@ -323,7 +374,7 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size)
323
374
  {
324
375
  #pragma omp parallel
325
376
  {
326
- DistanceComputer *dis = storage->get_distance_computer();
377
+ DistanceComputer *dis = storage_distance_computer(storage);
327
378
  ScopeDeleter1<DistanceComputer> del(dis);
328
379
 
329
380
  #pragma omp for
@@ -367,7 +418,7 @@ void IndexHNSW::search_level_0(
367
418
  storage_idx_t ntotal = hnsw.levels.size();
368
419
  #pragma omp parallel
369
420
  {
370
- DistanceComputer *qdis = storage->get_distance_computer();
421
+ DistanceComputer *qdis = storage_distance_computer(storage);
371
422
  ScopeDeleter1<DistanceComputer> del(qdis);
372
423
 
373
424
  VisitedTable vt (ntotal);
@@ -436,7 +487,7 @@ void IndexHNSW::init_level_0_from_knngraph(
436
487
 
437
488
  #pragma omp parallel for
438
489
  for (idx_t i = 0; i < ntotal; i++) {
439
- DistanceComputer *qdis = storage->get_distance_computer();
490
+ DistanceComputer *qdis = storage_distance_computer(storage);
440
491
  float vec[d];
441
492
  storage->reconstruct(i, vec);
442
493
  qdis->set_query(vec);
@@ -480,7 +531,7 @@ void IndexHNSW::init_level_0_from_entry_points(
480
531
  {
481
532
  VisitedTable vt (ntotal);
482
533
 
483
- DistanceComputer *dis = storage->get_distance_computer();
534
+ DistanceComputer *dis = storage_distance_computer(storage);
484
535
  ScopeDeleter1<DistanceComputer> del(dis);
485
536
  float vec[storage->d];
486
537
 
@@ -518,7 +569,7 @@ void IndexHNSW::reorder_links()
518
569
  std::vector<float> distances (M);
519
570
  std::vector<size_t> order (M);
520
571
  std::vector<storage_idx_t> tmp (M);
521
- DistanceComputer *dis = storage->get_distance_computer();
572
+ DistanceComputer *dis = storage_distance_computer(storage);
522
573
  ScopeDeleter1<DistanceComputer> del(dis);
523
574
 
524
575
  #pragma omp for
@@ -826,8 +877,8 @@ IndexHNSWFlat::IndexHNSWFlat()
826
877
  is_trained = true;
827
878
  }
828
879
 
829
- IndexHNSWFlat::IndexHNSWFlat(int d, int M):
830
- IndexHNSW(new IndexFlatL2(d), M)
880
+ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric):
881
+ IndexHNSW(new IndexFlat(d, metric), M)
831
882
  {
832
883
  own_fields = true;
833
884
  is_trained = true;
@@ -860,8 +911,9 @@ void IndexHNSWPQ::train(idx_t n, const float* x)
860
911
  **************************************************************/
861
912
 
862
913
 
863
- IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M):
864
- IndexHNSW (new IndexScalarQuantizer (d, qtype), M)
914
+ IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M,
915
+ MetricType metric):
916
+ IndexHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
865
917
  {
866
918
  is_trained = false;
867
919
  own_fields = true;
@@ -986,7 +1038,7 @@ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
986
1038
  #pragma omp parallel
987
1039
  {
988
1040
  VisitedTable vt (ntotal);
989
- DistanceComputer *dis = storage->get_distance_computer();
1041
+ DistanceComputer *dis = storage_distance_computer(storage);
990
1042
  ScopeDeleter1<DistanceComputer> del(dis);
991
1043
 
992
1044
  int candidates_size = hnsw.upper_beam;
@@ -79,7 +79,7 @@ struct IndexHNSW : Index {
79
79
 
80
80
  ReconstructFromNeighbors *reconstruct_from_neighbors;
81
81
 
82
- explicit IndexHNSW (int d = 0, int M = 32);
82
+ explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
83
83
  explicit IndexHNSW (Index *storage, int M = 32);
84
84
 
85
85
  ~IndexHNSW() override;
@@ -132,7 +132,7 @@ struct IndexHNSW : Index {
132
132
 
133
133
  struct IndexHNSWFlat : IndexHNSW {
134
134
  IndexHNSWFlat();
135
- IndexHNSWFlat(int d, int M);
135
+ IndexHNSWFlat(int d, int M, MetricType metric = METRIC_L2);
136
136
  };
137
137
 
138
138
  /** PQ index topped with a HNSW structure to access elements
@@ -149,7 +149,7 @@ struct IndexHNSWPQ : IndexHNSW {
149
149
  */
150
150
  struct IndexHNSWSQ : IndexHNSW {
151
151
  IndexHNSWSQ();
152
- IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
152
+ IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M, MetricType metric = METRIC_L2);
153
153
  };
154
154
 
155
155
  /** 2-level code structure with fast random access
@@ -157,8 +157,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
157
157
  code_size (code_size),
158
158
  nprobe (1),
159
159
  max_codes (0),
160
- parallel_mode (0),
161
- maintain_direct_map (false)
160
+ parallel_mode (0)
162
161
  {
163
162
  FAISS_THROW_IF_NOT (d == quantizer->d);
164
163
  is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
@@ -172,8 +171,7 @@ IndexIVF::IndexIVF (Index * quantizer, size_t d,
172
171
  IndexIVF::IndexIVF ():
173
172
  invlists (nullptr), own_invlists (false),
174
173
  code_size (0),
175
- nprobe (1), max_codes (0), parallel_mode (0),
176
- maintain_direct_map (false)
174
+ nprobe (1), max_codes (0), parallel_mode (0)
177
175
  {}
178
176
 
179
177
  void IndexIVF::add (idx_t n, const float * x)
@@ -199,6 +197,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
199
197
  }
200
198
 
201
199
  FAISS_THROW_IF_NOT (is_trained);
200
+ direct_map.check_can_add (xids);
201
+
202
202
  std::unique_ptr<idx_t []> idx(new idx_t[n]);
203
203
  quantizer->assign (n, x, idx.get());
204
204
  size_t nadd = 0, nminus1 = 0;
@@ -210,6 +210,8 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
210
210
  std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
211
211
  encode_vectors (n, x, idx.get(), flat_codes.get());
212
212
 
213
+ DirectMapAdd dm_adder(direct_map, n, xids);
214
+
213
215
  #pragma omp parallel reduction(+: nadd)
214
216
  {
215
217
  int nt = omp_get_num_threads();
@@ -220,13 +222,21 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
220
222
  idx_t list_no = idx [i];
221
223
  if (list_no >= 0 && list_no % nt == rank) {
222
224
  idx_t id = xids ? xids[i] : ntotal + i;
223
- invlists->add_entry (list_no, id,
224
- flat_codes.get() + i * code_size);
225
+ size_t ofs = invlists->add_entry (
226
+ list_no, id,
227
+ flat_codes.get() + i * code_size
228
+ );
229
+
230
+ dm_adder.add (i, list_no, ofs);
231
+
225
232
  nadd++;
233
+ } else if (rank == 0 && list_no == -1) {
234
+ dm_adder.add (i, -1, 0);
226
235
  }
227
236
  }
228
237
  }
229
238
 
239
+
230
240
  if (verbose) {
231
241
  printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1);
232
242
  }
@@ -234,30 +244,18 @@ void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
234
244
  ntotal += n;
235
245
  }
236
246
 
237
-
238
- void IndexIVF::make_direct_map (bool new_maintain_direct_map)
247
+ void IndexIVF::make_direct_map (bool b)
239
248
  {
240
- // nothing to do
241
- if (new_maintain_direct_map == maintain_direct_map)
242
- return;
243
-
244
- if (new_maintain_direct_map) {
245
- direct_map.resize (ntotal, -1);
246
- for (size_t key = 0; key < nlist; key++) {
247
- size_t list_size = invlists->list_size (key);
248
- ScopedIds idlist (invlists, key);
249
-
250
- for (long ofs = 0; ofs < list_size; ofs++) {
251
- FAISS_THROW_IF_NOT_MSG (
252
- 0 <= idlist [ofs] && idlist[ofs] < ntotal,
253
- "direct map supported only for seuquential ids");
254
- direct_map [idlist [ofs]] = key << 32 | ofs;
255
- }
256
- }
249
+ if (b) {
250
+ direct_map.set_type (DirectMap::Array, invlists, ntotal);
257
251
  } else {
258
- direct_map.clear ();
252
+ direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
259
253
  }
260
- maintain_direct_map = new_maintain_direct_map;
254
+ }
255
+
256
+ void IndexIVF::set_direct_map_type (DirectMap::Type type)
257
+ {
258
+ direct_map.set_type (type, invlists, ntotal);
261
259
  }
262
260
 
263
261
 
@@ -298,10 +296,13 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
298
296
 
299
297
  bool interrupt = false;
300
298
 
299
+ int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
300
+ bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
301
+
301
302
  // don't start parallel section if single query
302
303
  bool do_parallel =
303
- parallel_mode == 0 ? n > 1 :
304
- parallel_mode == 1 ? nprobe > 1 :
304
+ pmode == 0 ? n > 1 :
305
+ pmode == 1 ? nprobe > 1 :
305
306
  nprobe * n > 1;
306
307
 
307
308
  #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
@@ -318,6 +319,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
318
319
  // initialize + reorder a result heap
319
320
 
320
321
  auto init_result = [&](float *simi, idx_t *idxi) {
322
+ if (!do_heap_init) return;
321
323
  if (metric_type == METRIC_INNER_PRODUCT) {
322
324
  heap_heapify<HeapForIP> (k, simi, idxi);
323
325
  } else {
@@ -326,6 +328,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
326
328
  };
327
329
 
328
330
  auto reorder_result = [&] (float *simi, idx_t *idxi) {
331
+ if (!do_heap_init) return;
329
332
  if (metric_type == METRIC_INNER_PRODUCT) {
330
333
  heap_reorder<HeapForIP> (k, simi, idxi);
331
334
  } else {
@@ -377,7 +380,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
377
380
  * Actual loops, depending on parallel_mode
378
381
  ****************************************************/
379
382
 
380
- if (parallel_mode == 0) {
383
+ if (pmode == 0) {
381
384
 
382
385
  #pragma omp for
383
386
  for (size_t i = 0; i < n; i++) {
@@ -417,7 +420,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
417
420
  }
418
421
 
419
422
  } // parallel for
420
- } else if (parallel_mode == 1) {
423
+ } else if (pmode == 1) {
421
424
  std::vector <idx_t> local_idx (k);
422
425
  std::vector <float> local_dis (k);
423
426
 
@@ -460,7 +463,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
460
463
  }
461
464
  } else {
462
465
  FAISS_THROW_FMT ("parallel_mode %d not supported\n",
463
- parallel_mode);
466
+ pmode);
464
467
  }
465
468
  } // parallel section
466
469
 
@@ -608,13 +611,8 @@ InvertedListScanner *IndexIVF::get_InvertedListScanner (
608
611
 
609
612
  void IndexIVF::reconstruct (idx_t key, float* recons) const
610
613
  {
611
- FAISS_THROW_IF_NOT_MSG (direct_map.size() == ntotal,
612
- "direct map is not initialized");
613
- FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
614
- "invalid key");
615
- idx_t list_no = direct_map[key] >> 32;
616
- idx_t offset = direct_map[key] & 0xffffffff;
617
- reconstruct_from_offset (list_no, offset, recons);
614
+ idx_t lo = direct_map.get (key);
615
+ reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
618
616
  }
619
617
 
620
618
 
@@ -682,8 +680,8 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
682
680
  // Fill with NaNs
683
681
  memset(reconstructed, -1, sizeof(*reconstructed) * d);
684
682
  } else {
685
- int list_no = key >> 32;
686
- int offset = key & 0xffffffff;
683
+ int list_no = lo_listno (key);
684
+ int offset = lo_offset (key);
687
685
 
688
686
  // Update label to the actual id
689
687
  labels[ij] = invlists->get_single_id (list_no, offset);
@@ -711,42 +709,41 @@ void IndexIVF::reset ()
711
709
 
712
710
  size_t IndexIVF::remove_ids (const IDSelector & sel)
713
711
  {
714
- FAISS_THROW_IF_NOT_MSG (!maintain_direct_map,
715
- "direct map remove not implemented");
716
-
717
- std::vector<idx_t> toremove(nlist);
718
-
719
- #pragma omp parallel for
720
- for (idx_t i = 0; i < nlist; i++) {
721
- idx_t l0 = invlists->list_size (i), l = l0, j = 0;
722
- ScopedIds idsi (invlists, i);
723
- while (j < l) {
724
- if (sel.is_member (idsi[j])) {
725
- l--;
726
- invlists->update_entry (
727
- i, j,
728
- invlists->get_single_id (i, l),
729
- ScopedCodes (invlists, i, l).get());
730
- } else {
731
- j++;
732
- }
733
- }
734
- toremove[i] = l0 - l;
735
- }
736
- // this will not run well in parallel on ondisk because of possible shrinks
737
- size_t nremove = 0;
738
- for (idx_t i = 0; i < nlist; i++) {
739
- if (toremove[i] > 0) {
740
- nremove += toremove[i];
741
- invlists->resize(
742
- i, invlists->list_size(i) - toremove[i]);
743
- }
744
- }
712
+ size_t nremove = direct_map.remove_ids (sel, invlists);
745
713
  ntotal -= nremove;
746
714
  return nremove;
747
715
  }
748
716
 
749
717
 
718
+ void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
719
+ {
720
+
721
+ if (direct_map.type == DirectMap::Hashtable) {
722
+ // just remove then add
723
+ IDSelectorArray sel(n, new_ids);
724
+ size_t nremove = remove_ids (sel);
725
+ FAISS_THROW_IF_NOT_MSG (nremove == n,
726
+ "did not find all entries to remove");
727
+ add_with_ids (n, x, new_ids);
728
+ return;
729
+ }
730
+
731
+ FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
732
+ // here it is more tricky because we don't want to introduce holes
733
+ // in continuous range of ids
734
+
735
+ FAISS_THROW_IF_NOT (is_trained);
736
+ std::vector<idx_t> assign (n);
737
+ quantizer->assign (n, x, assign.data());
738
+
739
+ std::vector<uint8_t> flat_codes (n * code_size);
740
+ encode_vectors (n, x, assign.data(), flat_codes.data());
741
+
742
+ direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
743
+
744
+ }
745
+
746
+
750
747
 
751
748
 
752
749
  void IndexIVF::train (idx_t n, const float *x)
@@ -779,15 +776,14 @@ void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
779
776
  FAISS_THROW_IF_NOT (other.code_size == code_size);
780
777
  FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
781
778
  "can only merge indexes of the same type");
779
+ FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
780
+ "merge direct_map not implemented");
782
781
  }
783
782
 
784
783
 
785
784
  void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
786
785
  {
787
786
  check_compatible_for_merge (other);
788
- FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
789
- !other.maintain_direct_map),
790
- "direct map copy not implemented");
791
787
 
792
788
  invlists->merge_from (other.invlists, add_id);
793
789
 
@@ -817,7 +813,7 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
817
813
 
818
814
  FAISS_THROW_IF_NOT (nlist == other.nlist);
819
815
  FAISS_THROW_IF_NOT (code_size == other.code_size);
820
- FAISS_THROW_IF_NOT (!other.maintain_direct_map);
816
+ FAISS_THROW_IF_NOT (other.direct_map.no());
821
817
  FAISS_THROW_IF_NOT_FMT (
822
818
  subset_type == 0 || subset_type == 1 || subset_type == 2,
823
819
  "subset type %d not implemented", subset_type);