faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -5,17 +5,14 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #pragma once
10
9
 
11
10
  #include <faiss/IndexPQ.h>
12
11
  #include <faiss/impl/ProductQuantizer.h>
13
12
  #include <faiss/utils/AlignedTable.h>
14
13
 
15
-
16
14
  namespace faiss {
17
15
 
18
-
19
16
  /** Fast scan version of IndexPQ. Works for 4-bit PQ for now.
20
17
  *
21
18
  * The codes are not stored sequentially but grouped in blocks of size bbs.
@@ -28,7 +25,7 @@ namespace faiss {
28
25
  * 15: no qbs with reservoir accumulator
29
26
  */
30
27
 
31
- struct IndexPQFastScan: Index {
28
+ struct IndexPQFastScan : Index {
32
29
  ProductQuantizer pq;
33
30
 
34
31
  // implementation to select
@@ -37,8 +34,8 @@ struct IndexPQFastScan: Index {
37
34
  int skip = 0;
38
35
 
39
36
  // size of the kernel
40
- int bbs; // set at build time
41
- int qbs = 0; // query block size 0 = use default
37
+ int bbs; // set at build time
38
+ int qbs = 0; // query block size 0 = use default
42
39
 
43
40
  // packed version of the codes
44
41
  size_t ntotal2;
@@ -47,22 +44,23 @@ struct IndexPQFastScan: Index {
47
44
  AlignedTable<uint8_t> codes;
48
45
 
49
46
  // this is for testing purposes only (set when initialized by IndexPQ)
50
- const uint8_t *orig_codes = nullptr;
47
+ const uint8_t* orig_codes = nullptr;
51
48
 
52
49
  IndexPQFastScan(
53
- int d, size_t M, size_t nbits,
54
- MetricType metric = METRIC_L2,
55
- int bbs = 32
56
- );
50
+ int d,
51
+ size_t M,
52
+ size_t nbits,
53
+ MetricType metric = METRIC_L2,
54
+ int bbs = 32);
57
55
 
58
56
  IndexPQFastScan();
59
57
 
60
58
  /// build from an existing IndexPQ
61
- explicit IndexPQFastScan(const IndexPQ & orig, int bbs = 32);
59
+ explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
62
60
 
63
- void train (idx_t n, const float *x) override;
64
- void add (idx_t n, const float *x) override;
65
- void reset() override ;
61
+ void train(idx_t n, const float* x) override;
62
+ void add(idx_t n, const float* x) override;
63
+ void reset() override;
66
64
  void search(
67
65
  idx_t n,
68
66
  const float* x,
@@ -72,35 +70,51 @@ struct IndexPQFastScan: Index {
72
70
 
73
71
  // called by search function
74
72
  void compute_quantized_LUT(
75
- idx_t n, const float* x,
76
- uint8_t *lut, float *normalizers) const ;
73
+ idx_t n,
74
+ const float* x,
75
+ uint8_t* lut,
76
+ float* normalizers) const;
77
77
 
78
- template<bool is_max>
78
+ template <bool is_max>
79
79
  void search_dispatch_implem(
80
- idx_t n, const float* x, idx_t k,
81
- float* distances, idx_t* labels) const;
80
+ idx_t n,
81
+ const float* x,
82
+ idx_t k,
83
+ float* distances,
84
+ idx_t* labels) const;
82
85
 
83
- template<class C>
86
+ template <class C>
84
87
  void search_implem_2(
85
- idx_t n, const float* x, idx_t k,
86
- float* distances, idx_t* labels) const;
87
-
88
+ idx_t n,
89
+ const float* x,
90
+ idx_t k,
91
+ float* distances,
92
+ idx_t* labels) const;
88
93
 
89
- template<class C>
94
+ template <class C>
90
95
  void search_implem_12(
91
- idx_t n, const float* x, idx_t k,
92
- float* distances, idx_t* labels, int impl) const;
96
+ idx_t n,
97
+ const float* x,
98
+ idx_t k,
99
+ float* distances,
100
+ idx_t* labels,
101
+ int impl) const;
93
102
 
94
- template<class C>
103
+ template <class C>
95
104
  void search_implem_14(
96
- idx_t n, const float* x, idx_t k,
97
- float* distances, idx_t* labels, int impl) const;
98
-
105
+ idx_t n,
106
+ const float* x,
107
+ idx_t k,
108
+ float* distances,
109
+ idx_t* labels,
110
+ int impl) const;
99
111
  };
100
112
 
101
113
  struct FastScanStats {
102
114
  uint64_t t0, t1, t2, t3;
103
- FastScanStats() {reset();}
115
+ FastScanStats() {
116
+ reset();
117
+ }
104
118
  void reset() {
105
119
  memset(this, 0, sizeof(*this));
106
120
  }
@@ -9,13 +9,13 @@
9
9
 
10
10
  #include <faiss/IndexPreTransform.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <cmath>
13
+ #include <cstdio>
14
14
  #include <cstring>
15
15
  #include <memory>
16
16
 
17
- #include <faiss/impl/FaissAssert.h>
18
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
19
 
20
20
  namespace faiss {
21
21
 
@@ -23,44 +23,29 @@ namespace faiss {
23
23
  * IndexPreTransform
24
24
  *********************************************/
25
25
 
26
- IndexPreTransform::IndexPreTransform ():
27
- index(nullptr), own_fields (false)
28
- {
29
- }
30
-
26
+ IndexPreTransform::IndexPreTransform() : index(nullptr), own_fields(false) {}
31
27
 
32
- IndexPreTransform::IndexPreTransform (
33
- Index * index):
34
- Index (index->d, index->metric_type),
35
- index (index), own_fields (false)
36
- {
28
+ IndexPreTransform::IndexPreTransform(Index* index)
29
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
37
30
  is_trained = index->is_trained;
38
31
  ntotal = index->ntotal;
39
32
  }
40
33
 
41
-
42
- IndexPreTransform::IndexPreTransform (
43
- VectorTransform * ltrans,
44
- Index * index):
45
- Index (index->d, index->metric_type),
46
- index (index), own_fields (false)
47
- {
34
+ IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
35
+ : Index(index->d, index->metric_type), index(index), own_fields(false) {
48
36
  is_trained = index->is_trained;
49
37
  ntotal = index->ntotal;
50
- prepend_transform (ltrans);
38
+ prepend_transform(ltrans);
51
39
  }
52
40
 
53
- void IndexPreTransform::prepend_transform (VectorTransform *ltrans)
54
- {
55
- FAISS_THROW_IF_NOT (ltrans->d_out == d);
41
+ void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
42
+ FAISS_THROW_IF_NOT(ltrans->d_out == d);
56
43
  is_trained = is_trained && ltrans->is_trained;
57
- chain.insert (chain.begin(), ltrans);
44
+ chain.insert(chain.begin(), ltrans);
58
45
  d = ltrans->d_in;
59
46
  }
60
47
 
61
-
62
- IndexPreTransform::~IndexPreTransform ()
63
- {
48
+ IndexPreTransform::~IndexPreTransform() {
64
49
  if (own_fields) {
65
50
  for (int i = 0; i < chain.size(); i++)
66
51
  delete chain[i];
@@ -68,11 +53,7 @@ IndexPreTransform::~IndexPreTransform ()
68
53
  }
69
54
  }
70
55
 
71
-
72
-
73
-
74
- void IndexPreTransform::train (idx_t n, const float *x)
75
- {
56
+ void IndexPreTransform::train(idx_t n, const float* x) {
76
57
  int last_untrained = 0;
77
58
  if (!index->is_trained) {
78
59
  last_untrained = chain.size();
@@ -84,7 +65,7 @@ void IndexPreTransform::train (idx_t n, const float *x)
84
65
  }
85
66
  }
86
67
  }
87
- const float *prev_x = x;
68
+ const float* prev_x = x;
88
69
  ScopeDeleter<float> del;
89
70
 
90
71
  if (verbose) {
@@ -93,34 +74,35 @@ void IndexPreTransform::train (idx_t n, const float *x)
93
74
  }
94
75
 
95
76
  for (int i = 0; i <= last_untrained; i++) {
96
-
97
77
  if (i < chain.size()) {
98
- VectorTransform *ltrans = chain [i];
78
+ VectorTransform* ltrans = chain[i];
99
79
  if (!ltrans->is_trained) {
100
80
  if (verbose) {
101
81
  printf(" Training chain component %d/%zd\n",
102
- i, chain.size());
103
- if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
82
+ i,
83
+ chain.size());
84
+ if (OPQMatrix* opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
104
85
  opqm->verbose = true;
105
86
  }
106
87
  }
107
- ltrans->train (n, prev_x);
88
+ ltrans->train(n, prev_x);
108
89
  }
109
90
  } else {
110
91
  if (verbose) {
111
92
  printf(" Training sub-index\n");
112
93
  }
113
- index->train (n, prev_x);
94
+ index->train(n, prev_x);
114
95
  }
115
- if (i == last_untrained) break;
96
+ if (i == last_untrained)
97
+ break;
116
98
  if (verbose) {
117
- printf(" Applying transform %d/%zd\n",
118
- i, chain.size());
99
+ printf(" Applying transform %d/%zd\n", i, chain.size());
119
100
  }
120
101
 
121
- float * xt = chain[i]->apply (n, prev_x);
102
+ float* xt = chain[i]->apply(n, prev_x);
122
103
 
123
- if (prev_x != x) delete [] prev_x;
104
+ if (prev_x != x)
105
+ delete[] prev_x;
124
106
  prev_x = xt;
125
107
  del.set(xt);
126
108
  }
@@ -128,200 +110,190 @@ void IndexPreTransform::train (idx_t n, const float *x)
128
110
  is_trained = true;
129
111
  }
130
112
 
131
-
132
- const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
133
- {
134
- const float *prev_x = x;
113
+ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
114
+ const float* prev_x = x;
135
115
  ScopeDeleter<float> del;
136
116
 
137
117
  for (int i = 0; i < chain.size(); i++) {
138
- float * xt = chain[i]->apply (n, prev_x);
139
- ScopeDeleter<float> del2 (xt);
140
- del2.swap (del);
118
+ float* xt = chain[i]->apply(n, prev_x);
119
+ ScopeDeleter<float> del2(xt);
120
+ del2.swap(del);
141
121
  prev_x = xt;
142
122
  }
143
- del.release ();
123
+ del.release();
144
124
  return prev_x;
145
125
  }
146
126
 
147
- void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
148
- {
127
+ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
128
+ const {
149
129
  const float* next_x = xt;
150
130
  ScopeDeleter<float> del;
151
131
 
152
132
  for (int i = chain.size() - 1; i >= 0; i--) {
153
- float* prev_x = (i == 0) ? x : new float [n * chain[i]->d_in];
154
- ScopeDeleter<float> del2 ((prev_x == x) ? nullptr : prev_x);
155
- chain [i]->reverse_transform (n, next_x, prev_x);
156
- del2.swap (del);
133
+ float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
134
+ ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
135
+ chain[i]->reverse_transform(n, next_x, prev_x);
136
+ del2.swap(del);
157
137
  next_x = prev_x;
158
138
  }
159
139
  }
160
140
 
161
- void IndexPreTransform::add (idx_t n, const float *x)
162
- {
163
- FAISS_THROW_IF_NOT (is_trained);
164
- const float *xt = apply_chain (n, x);
141
+ void IndexPreTransform::add(idx_t n, const float* x) {
142
+ FAISS_THROW_IF_NOT(is_trained);
143
+ const float* xt = apply_chain(n, x);
165
144
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
166
- index->add (n, xt);
145
+ index->add(n, xt);
167
146
  ntotal = index->ntotal;
168
147
  }
169
148
 
170
- void IndexPreTransform::add_with_ids (idx_t n, const float * x,
171
- const idx_t *xids)
172
- {
173
- FAISS_THROW_IF_NOT (is_trained);
174
- const float *xt = apply_chain (n, x);
149
+ void IndexPreTransform::add_with_ids(
150
+ idx_t n,
151
+ const float* x,
152
+ const idx_t* xids) {
153
+ FAISS_THROW_IF_NOT(is_trained);
154
+ const float* xt = apply_chain(n, x);
175
155
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
176
- index->add_with_ids (n, xt, xids);
156
+ index->add_with_ids(n, xt, xids);
177
157
  ntotal = index->ntotal;
178
158
  }
179
159
 
160
+ void IndexPreTransform::search(
161
+ idx_t n,
162
+ const float* x,
163
+ idx_t k,
164
+ float* distances,
165
+ idx_t* labels) const {
166
+ FAISS_THROW_IF_NOT(k > 0);
180
167
 
181
-
182
-
183
- void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
184
- float *distances, idx_t *labels) const
185
- {
186
- FAISS_THROW_IF_NOT (is_trained);
187
- const float *xt = apply_chain (n, x);
168
+ FAISS_THROW_IF_NOT(is_trained);
169
+ const float* xt = apply_chain(n, x);
188
170
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
189
- index->search (n, xt, k, distances, labels);
171
+ index->search(n, xt, k, distances, labels);
190
172
  }
191
173
 
192
- void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
193
- RangeSearchResult* result) const
194
- {
195
- FAISS_THROW_IF_NOT (is_trained);
196
- const float *xt = apply_chain (n, x);
174
+ void IndexPreTransform::range_search(
175
+ idx_t n,
176
+ const float* x,
177
+ float radius,
178
+ RangeSearchResult* result) const {
179
+ FAISS_THROW_IF_NOT(is_trained);
180
+ const float* xt = apply_chain(n, x);
197
181
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
198
- index->range_search (n, xt, radius, result);
182
+ index->range_search(n, xt, radius, result);
199
183
  }
200
184
 
201
-
202
-
203
- void IndexPreTransform::reset () {
185
+ void IndexPreTransform::reset() {
204
186
  index->reset();
205
187
  ntotal = 0;
206
188
  }
207
189
 
208
- size_t IndexPreTransform::remove_ids (const IDSelector & sel) {
209
- size_t nremove = index->remove_ids (sel);
190
+ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
191
+ size_t nremove = index->remove_ids(sel);
210
192
  ntotal = index->ntotal;
211
193
  return nremove;
212
194
  }
213
195
 
214
-
215
- void IndexPreTransform::reconstruct (idx_t key, float * recons) const
216
- {
217
- float *x = chain.empty() ? recons : new float [index->d];
218
- ScopeDeleter<float> del (recons == x ? nullptr : x);
196
+ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
197
+ float* x = chain.empty() ? recons : new float[index->d];
198
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
219
199
  // Initial reconstruction
220
- index->reconstruct (key, x);
200
+ index->reconstruct(key, x);
221
201
 
222
202
  // Revert transformations from last to first
223
- reverse_chain (1, x, recons);
203
+ reverse_chain(1, x, recons);
224
204
  }
225
205
 
226
-
227
- void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
228
- {
229
- float *x = chain.empty() ? recons : new float [ni * index->d];
230
- ScopeDeleter<float> del (recons == x ? nullptr : x);
206
+ void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
207
+ float* x = chain.empty() ? recons : new float[ni * index->d];
208
+ ScopeDeleter<float> del(recons == x ? nullptr : x);
231
209
  // Initial reconstruction
232
- index->reconstruct_n (i0, ni, x);
210
+ index->reconstruct_n(i0, ni, x);
233
211
 
234
212
  // Revert transformations from last to first
235
- reverse_chain (ni, x, recons);
213
+ reverse_chain(ni, x, recons);
236
214
  }
237
215
 
216
+ void IndexPreTransform::search_and_reconstruct(
217
+ idx_t n,
218
+ const float* x,
219
+ idx_t k,
220
+ float* distances,
221
+ idx_t* labels,
222
+ float* recons) const {
223
+ FAISS_THROW_IF_NOT(k > 0);
238
224
 
239
- void IndexPreTransform::search_and_reconstruct (
240
- idx_t n, const float *x, idx_t k,
241
- float *distances, idx_t *labels, float* recons) const
242
- {
243
- FAISS_THROW_IF_NOT (is_trained);
225
+ FAISS_THROW_IF_NOT(is_trained);
244
226
 
245
- const float* xt = apply_chain (n, x);
246
- ScopeDeleter<float> del ((xt == x) ? nullptr : xt);
227
+ const float* xt = apply_chain(n, x);
228
+ ScopeDeleter<float> del((xt == x) ? nullptr : xt);
247
229
 
248
- float* recons_temp = chain.empty() ? recons : new float [n * k * index->d];
249
- ScopeDeleter<float> del2 ((recons_temp == recons) ? nullptr : recons_temp);
250
- index->search_and_reconstruct (n, xt, k, distances, labels, recons_temp);
230
+ float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
231
+ ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
232
+ index->search_and_reconstruct(n, xt, k, distances, labels, recons_temp);
251
233
 
252
234
  // Revert transformations from last to first
253
- reverse_chain (n * k, recons_temp, recons);
235
+ reverse_chain(n * k, recons_temp, recons);
254
236
  }
255
237
 
256
- size_t IndexPreTransform::sa_code_size () const
257
- {
258
- return index->sa_code_size ();
238
+ size_t IndexPreTransform::sa_code_size() const {
239
+ return index->sa_code_size();
259
240
  }
260
241
 
261
- void IndexPreTransform::sa_encode (idx_t n, const float *x,
262
- uint8_t *bytes) const
263
- {
242
+ void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
243
+ const {
264
244
  if (chain.empty()) {
265
- index->sa_encode (n, x, bytes);
245
+ index->sa_encode(n, x, bytes);
266
246
  } else {
267
- const float *xt = apply_chain (n, x);
247
+ const float* xt = apply_chain(n, x);
268
248
  ScopeDeleter<float> del(xt == x ? nullptr : xt);
269
- index->sa_encode (n, xt, bytes);
249
+ index->sa_encode(n, xt, bytes);
270
250
  }
271
251
  }
272
252
 
273
- void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
274
- float *x) const
275
- {
253
+ void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
254
+ const {
276
255
  if (chain.empty()) {
277
- index->sa_decode (n, bytes, x);
256
+ index->sa_decode(n, bytes, x);
278
257
  } else {
279
- std::unique_ptr<float []> x1 (new float [index->d * n]);
280
- index->sa_decode (n, bytes, x1.get());
258
+ std::unique_ptr<float[]> x1(new float[index->d * n]);
259
+ index->sa_decode(n, bytes, x1.get());
281
260
  // Revert transformations from last to first
282
- reverse_chain (n, x1.get(), x);
261
+ reverse_chain(n, x1.get(), x);
283
262
  }
284
263
  }
285
264
 
286
265
  namespace {
287
266
 
288
- struct PreTransformDistanceComputer: DistanceComputer {
289
- const IndexPreTransform *index;
267
+ struct PreTransformDistanceComputer : DistanceComputer {
268
+ const IndexPreTransform* index;
290
269
  std::unique_ptr<DistanceComputer> sub_dc;
291
- std::unique_ptr<const float []> query;
270
+ std::unique_ptr<const float[]> query;
292
271
 
293
- explicit PreTransformDistanceComputer(const IndexPreTransform *index):
294
- index(index),
295
- sub_dc(index->index->get_distance_computer())
296
- {}
272
+ explicit PreTransformDistanceComputer(const IndexPreTransform* index)
273
+ : index(index), sub_dc(index->index->get_distance_computer()) {}
297
274
 
298
- void set_query(const float *x) override {
299
- const float *xt = index->apply_chain (1, x);
275
+ void set_query(const float* x) override {
276
+ const float* xt = index->apply_chain(1, x);
300
277
  if (xt == x) {
301
- sub_dc->set_query (x);
278
+ sub_dc->set_query(x);
302
279
  } else {
303
280
  query.reset(xt);
304
- sub_dc->set_query (xt);
281
+ sub_dc->set_query(xt);
305
282
  }
306
283
  }
307
284
 
308
- float symmetric_dis(idx_t i, idx_t j) override
309
- {
285
+ float symmetric_dis(idx_t i, idx_t j) override {
310
286
  return sub_dc->symmetric_dis(i, j);
311
287
  }
312
288
 
313
- float operator () (idx_t i) override
314
- {
289
+ float operator()(idx_t i) override {
315
290
  return (*sub_dc)(i);
316
291
  }
317
-
318
292
  };
319
293
 
320
-
321
294
  } // anonymous namespace
322
295
 
323
-
324
- DistanceComputer * IndexPreTransform::get_distance_computer() const {
296
+ DistanceComputer* IndexPreTransform::get_distance_computer() const {
325
297
  if (chain.empty()) {
326
298
  return index->get_distance_computer();
327
299
  } else {
@@ -329,6 +301,4 @@ DistanceComputer * IndexPreTransform::get_distance_computer() const {
329
301
  }
330
302
  }
331
303
 
332
-
333
-
334
304
  } // namespace faiss