faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -10,20 +10,19 @@
10
10
  #include <faiss/VectorTransform.h>
11
11
 
12
12
  #include <cinttypes>
13
- #include <cstdio>
14
13
  #include <cmath>
14
+ #include <cstdio>
15
15
  #include <cstring>
16
16
  #include <memory>
17
17
 
18
+ #include <faiss/IndexPQ.h>
19
+ #include <faiss/impl/FaissAssert.h>
18
20
  #include <faiss/utils/distances.h>
19
21
  #include <faiss/utils/random.h>
20
22
  #include <faiss/utils/utils.h>
21
- #include <faiss/impl/FaissAssert.h>
22
- #include <faiss/IndexPQ.h>
23
23
 
24
24
  using namespace faiss;
25
25
 
26
-
27
26
  extern "C" {
28
27
 
29
28
  // this is to keep the clang syntax checker happy
@@ -31,134 +30,183 @@ extern "C" {
31
30
  #define FINTEGER int
32
31
  #endif
33
32
 
34
-
35
33
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
36
34
 
37
- int sgemm_ (
38
- const char *transa, const char *transb, FINTEGER *m, FINTEGER *
39
- n, FINTEGER *k, const float *alpha, const float *a,
40
- FINTEGER *lda, const float *b,
41
- FINTEGER *ldb, float *beta,
42
- float *c, FINTEGER *ldc);
43
-
44
- int dgemm_ (
45
- const char *transa, const char *transb, FINTEGER *m, FINTEGER *
46
- n, FINTEGER *k, const double *alpha, const double *a,
47
- FINTEGER *lda, const double *b,
48
- FINTEGER *ldb, double *beta,
49
- double *c, FINTEGER *ldc);
50
-
51
- int ssyrk_ (
52
- const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k,
53
- float *alpha, float *a, FINTEGER *lda,
54
- float *beta, float *c, FINTEGER *ldc);
35
+ int sgemm_(
36
+ const char* transa,
37
+ const char* transb,
38
+ FINTEGER* m,
39
+ FINTEGER* n,
40
+ FINTEGER* k,
41
+ const float* alpha,
42
+ const float* a,
43
+ FINTEGER* lda,
44
+ const float* b,
45
+ FINTEGER* ldb,
46
+ float* beta,
47
+ float* c,
48
+ FINTEGER* ldc);
49
+
50
+ int dgemm_(
51
+ const char* transa,
52
+ const char* transb,
53
+ FINTEGER* m,
54
+ FINTEGER* n,
55
+ FINTEGER* k,
56
+ const double* alpha,
57
+ const double* a,
58
+ FINTEGER* lda,
59
+ const double* b,
60
+ FINTEGER* ldb,
61
+ double* beta,
62
+ double* c,
63
+ FINTEGER* ldc);
64
+
65
+ int ssyrk_(
66
+ const char* uplo,
67
+ const char* trans,
68
+ FINTEGER* n,
69
+ FINTEGER* k,
70
+ float* alpha,
71
+ float* a,
72
+ FINTEGER* lda,
73
+ float* beta,
74
+ float* c,
75
+ FINTEGER* ldc);
55
76
 
56
77
  /* Lapack functions from http://www.netlib.org/clapack/old/single/ */
57
78
 
58
- int ssyev_ (
59
- const char *jobz, const char *uplo, FINTEGER *n, float *a,
60
- FINTEGER *lda, float *w, float *work, FINTEGER *lwork,
61
- FINTEGER *info);
62
-
63
- int dsyev_ (
64
- const char *jobz, const char *uplo, FINTEGER *n, double *a,
65
- FINTEGER *lda, double *w, double *work, FINTEGER *lwork,
66
- FINTEGER *info);
79
+ int ssyev_(
80
+ const char* jobz,
81
+ const char* uplo,
82
+ FINTEGER* n,
83
+ float* a,
84
+ FINTEGER* lda,
85
+ float* w,
86
+ float* work,
87
+ FINTEGER* lwork,
88
+ FINTEGER* info);
89
+
90
+ int dsyev_(
91
+ const char* jobz,
92
+ const char* uplo,
93
+ FINTEGER* n,
94
+ double* a,
95
+ FINTEGER* lda,
96
+ double* w,
97
+ double* work,
98
+ FINTEGER* lwork,
99
+ FINTEGER* info);
67
100
 
68
101
  int sgesvd_(
69
- const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
70
- float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt,
71
- FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info);
72
-
102
+ const char* jobu,
103
+ const char* jobvt,
104
+ FINTEGER* m,
105
+ FINTEGER* n,
106
+ float* a,
107
+ FINTEGER* lda,
108
+ float* s,
109
+ float* u,
110
+ FINTEGER* ldu,
111
+ float* vt,
112
+ FINTEGER* ldvt,
113
+ float* work,
114
+ FINTEGER* lwork,
115
+ FINTEGER* info);
73
116
 
74
117
  int dgesvd_(
75
- const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
76
- double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt,
77
- FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info);
78
-
118
+ const char* jobu,
119
+ const char* jobvt,
120
+ FINTEGER* m,
121
+ FINTEGER* n,
122
+ double* a,
123
+ FINTEGER* lda,
124
+ double* s,
125
+ double* u,
126
+ FINTEGER* ldu,
127
+ double* vt,
128
+ FINTEGER* ldvt,
129
+ double* work,
130
+ FINTEGER* lwork,
131
+ FINTEGER* info);
79
132
  }
80
133
 
81
134
  /*********************************************
82
135
  * VectorTransform
83
136
  *********************************************/
84
137
 
85
-
86
-
87
- float * VectorTransform::apply (Index::idx_t n, const float * x) const
88
- {
89
- float * xt = new float[n * d_out];
90
- apply_noalloc (n, x, xt);
138
+ float* VectorTransform::apply(Index::idx_t n, const float* x) const {
139
+ float* xt = new float[n * d_out];
140
+ apply_noalloc(n, x, xt);
91
141
  return xt;
92
142
  }
93
143
 
94
-
95
- void VectorTransform::train (idx_t, const float *) {
144
+ void VectorTransform::train(idx_t, const float*) {
96
145
  // does nothing by default
97
146
  }
98
147
 
99
-
100
- void VectorTransform::reverse_transform (
101
- idx_t , const float *,
102
- float *) const
103
- {
104
- FAISS_THROW_MSG ("reverse transform not implemented");
148
+ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
149
+ FAISS_THROW_MSG("reverse transform not implemented");
105
150
  }
106
151
 
107
-
108
-
109
-
110
152
  /*********************************************
111
153
  * LinearTransform
112
154
  *********************************************/
113
155
 
114
156
  /// both d_in > d_out and d_out < d_in are supported
115
- LinearTransform::LinearTransform (int d_in, int d_out,
116
- bool have_bias):
117
- VectorTransform (d_in, d_out), have_bias (have_bias),
118
- is_orthonormal (false), verbose (false)
119
- {
157
+ LinearTransform::LinearTransform(int d_in, int d_out, bool have_bias)
158
+ : VectorTransform(d_in, d_out),
159
+ have_bias(have_bias),
160
+ is_orthonormal(false),
161
+ verbose(false) {
120
162
  is_trained = false; // will be trained when A and b are initialized
121
163
  }
122
164
 
123
- void LinearTransform::apply_noalloc (Index::idx_t n, const float * x,
124
- float * xt) const
125
- {
165
+ void LinearTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
166
+ const {
126
167
  FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
127
168
 
128
169
  float c_factor;
129
170
  if (have_bias) {
130
- FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized");
131
- float * xi = xt;
171
+ FAISS_THROW_IF_NOT_MSG(b.size() == d_out, "Bias not initialized");
172
+ float* xi = xt;
132
173
  for (int i = 0; i < n; i++)
133
- for(int j = 0; j < d_out; j++)
174
+ for (int j = 0; j < d_out; j++)
134
175
  *xi++ = b[j];
135
176
  c_factor = 1.0;
136
177
  } else {
137
178
  c_factor = 0.0;
138
179
  }
139
180
 
140
- FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in,
141
- "Transformation matrix not initialized");
181
+ FAISS_THROW_IF_NOT_MSG(
182
+ A.size() == d_out * d_in, "Transformation matrix not initialized");
142
183
 
143
184
  float one = 1;
144
185
  FINTEGER nbiti = d_out, ni = n, di = d_in;
145
- sgemm_ ("Transposed", "Not transposed",
146
- &nbiti, &ni, &di,
147
- &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti);
148
-
186
+ sgemm_("Transposed",
187
+ "Not transposed",
188
+ &nbiti,
189
+ &ni,
190
+ &di,
191
+ &one,
192
+ A.data(),
193
+ &di,
194
+ x,
195
+ &di,
196
+ &c_factor,
197
+ xt,
198
+ &nbiti);
149
199
  }
150
200
 
151
-
152
- void LinearTransform::transform_transpose (idx_t n, const float * y,
153
- float *x) const
154
- {
201
+ void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
202
+ const {
155
203
  if (have_bias) { // allocate buffer to store bias-corrected data
156
- float *y_new = new float [n * d_out];
157
- const float *yr = y;
158
- float *yw = y_new;
204
+ float* y_new = new float[n * d_out];
205
+ const float* yr = y;
206
+ float* yw = y_new;
159
207
  for (idx_t i = 0; i < n; i++) {
160
208
  for (int j = 0; j < d_out; j++) {
161
- *yw++ = *yr++ - b [j];
209
+ *yw++ = *yr++ - b[j];
162
210
  }
163
211
  }
164
212
  y = y_new;
@@ -167,15 +215,26 @@ void LinearTransform::transform_transpose (idx_t n, const float * y,
167
215
  {
168
216
  FINTEGER dii = d_in, doi = d_out, ni = n;
169
217
  float one = 1.0, zero = 0.0;
170
- sgemm_ ("Not", "Not", &dii, &ni, &doi,
171
- &one, A.data (), &dii, y, &doi, &zero, x, &dii);
218
+ sgemm_("Not",
219
+ "Not",
220
+ &dii,
221
+ &ni,
222
+ &doi,
223
+ &one,
224
+ A.data(),
225
+ &dii,
226
+ y,
227
+ &doi,
228
+ &zero,
229
+ x,
230
+ &dii);
172
231
  }
173
232
 
174
- if (have_bias) delete [] y;
233
+ if (have_bias)
234
+ delete[] y;
175
235
  }
176
236
 
177
- void LinearTransform::set_is_orthonormal ()
178
- {
237
+ void LinearTransform::set_is_orthonormal() {
179
238
  if (d_out > d_in) {
180
239
  // not clear what we should do in this case
181
240
  is_orthonormal = false;
@@ -193,44 +252,53 @@ void LinearTransform::set_is_orthonormal ()
193
252
  FINTEGER dii = d_in, doi = d_out;
194
253
  float one = 1.0, zero = 0.0;
195
254
 
196
- sgemm_ ("Transposed", "Not", &doi, &doi, &dii,
197
- &one, A.data (), &dii,
198
- A.data(), &dii,
199
- &zero, ATA.data(), &doi);
255
+ sgemm_("Transposed",
256
+ "Not",
257
+ &doi,
258
+ &doi,
259
+ &dii,
260
+ &one,
261
+ A.data(),
262
+ &dii,
263
+ A.data(),
264
+ &dii,
265
+ &zero,
266
+ ATA.data(),
267
+ &doi);
200
268
 
201
269
  is_orthonormal = true;
202
270
  for (long i = 0; i < d_out; i++) {
203
271
  for (long j = 0; j < d_out; j++) {
204
272
  float v = ATA[i + j * d_out];
205
- if (i == j) v-= 1;
273
+ if (i == j)
274
+ v -= 1;
206
275
  if (fabs(v) > eps) {
207
276
  is_orthonormal = false;
208
277
  }
209
278
  }
210
279
  }
211
280
  }
212
-
213
281
  }
214
282
 
215
-
216
- void LinearTransform::reverse_transform (idx_t n, const float * xt,
217
- float *x) const
218
- {
283
+ void LinearTransform::reverse_transform(idx_t n, const float* xt, float* x)
284
+ const {
219
285
  if (is_orthonormal) {
220
- transform_transpose (n, xt, x);
286
+ transform_transpose(n, xt, x);
221
287
  } else {
222
- FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices");
288
+ FAISS_THROW_MSG(
289
+ "reverse transform not implemented for non-orthonormal matrices");
223
290
  }
224
291
  }
225
292
 
226
-
227
- void LinearTransform::print_if_verbose (
228
- const char*name, const std::vector<double> &mat,
229
- int n, int d) const
230
- {
231
- if (!verbose) return;
293
+ void LinearTransform::print_if_verbose(
294
+ const char* name,
295
+ const std::vector<double>& mat,
296
+ int n,
297
+ int d) const {
298
+ if (!verbose)
299
+ return;
232
300
  printf("matrix %s: %d*%d [\n", name, n, d);
233
- FAISS_THROW_IF_NOT (mat.size() >= n * d);
301
+ FAISS_THROW_IF_NOT(mat.size() >= n * d);
234
302
  for (int i = 0; i < n; i++) {
235
303
  for (int j = 0; j < d; j++) {
236
304
  printf("%10.5g ", mat[i * d + j]);
@@ -244,24 +312,22 @@ void LinearTransform::print_if_verbose (
244
312
  * RandomRotationMatrix
245
313
  *********************************************/
246
314
 
247
- void RandomRotationMatrix::init (int seed)
248
- {
249
-
250
- if(d_out <= d_in) {
251
- A.resize (d_out * d_in);
252
- float *q = A.data();
315
+ void RandomRotationMatrix::init(int seed) {
316
+ if (d_out <= d_in) {
317
+ A.resize(d_out * d_in);
318
+ float* q = A.data();
253
319
  float_randn(q, d_out * d_in, seed);
254
320
  matrix_qr(d_in, d_out, q);
255
321
  } else {
256
322
  // use tight-frame transformation
257
- A.resize (d_out * d_out);
258
- float *q = A.data();
323
+ A.resize(d_out * d_out);
324
+ float* q = A.data();
259
325
  float_randn(q, d_out * d_out, seed);
260
326
  matrix_qr(d_out, d_out, q);
261
327
  // remove columns
262
328
  int i, j;
263
329
  for (i = 0; i < d_out; i++) {
264
- for(j = 0; j < d_in; j++) {
330
+ for (j = 0; j < d_in; j++) {
265
331
  q[i * d_in + j] = q[i * d_out + j];
266
332
  }
267
333
  }
@@ -271,247 +337,281 @@ void RandomRotationMatrix::init (int seed)
271
337
  is_trained = true;
272
338
  }
273
339
 
274
- void RandomRotationMatrix::train (Index::idx_t /*n*/, const float * /*x*/)
275
- {
340
+ void RandomRotationMatrix::train(Index::idx_t /*n*/, const float* /*x*/) {
276
341
  // initialize with some arbitrary seed
277
- init (12345);
342
+ init(12345);
278
343
  }
279
344
 
280
-
281
345
  /*********************************************
282
346
  * PCAMatrix
283
347
  *********************************************/
284
348
 
285
- PCAMatrix::PCAMatrix (int d_in, int d_out,
286
- float eigen_power, bool random_rotation):
287
- LinearTransform(d_in, d_out, true),
288
- eigen_power(eigen_power), random_rotation(random_rotation)
289
- {
349
+ PCAMatrix::PCAMatrix(
350
+ int d_in,
351
+ int d_out,
352
+ float eigen_power,
353
+ bool random_rotation)
354
+ : LinearTransform(d_in, d_out, true),
355
+ eigen_power(eigen_power),
356
+ random_rotation(random_rotation) {
290
357
  is_trained = false;
291
358
  max_points_per_d = 1000;
292
359
  balanced_bins = 0;
360
+ epsilon = 0;
293
361
  }
294
362
 
295
-
296
363
  namespace {
297
364
 
298
365
  /// Compute the eigenvalue decomposition of symmetric matrix cov,
299
366
  /// dimensions d_in-by-d_in. Output eigenvectors in cov.
300
367
 
301
- void eig(size_t d_in, double *cov, double *eigenvalues, int verbose)
302
- {
368
+ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
303
369
  { // compute eigenvalues and vectors
304
370
  FINTEGER info = 0, lwork = -1, di = d_in;
305
371
  double workq;
306
372
 
307
- dsyev_ ("Vectors as well", "Upper",
308
- &di, cov, &di, eigenvalues, &workq, &lwork, &info);
373
+ dsyev_("Vectors as well",
374
+ "Upper",
375
+ &di,
376
+ cov,
377
+ &di,
378
+ eigenvalues,
379
+ &workq,
380
+ &lwork,
381
+ &info);
309
382
  lwork = FINTEGER(workq);
310
- double *work = new double[lwork];
383
+ double* work = new double[lwork];
311
384
 
312
- dsyev_ ("Vectors as well", "Upper",
313
- &di, cov, &di, eigenvalues, work, &lwork, &info);
385
+ dsyev_("Vectors as well",
386
+ "Upper",
387
+ &di,
388
+ cov,
389
+ &di,
390
+ eigenvalues,
391
+ work,
392
+ &lwork,
393
+ &info);
314
394
 
315
- delete [] work;
395
+ delete[] work;
316
396
 
317
397
  if (info != 0) {
318
- fprintf (stderr, "WARN ssyev info returns %d, "
319
- "a very bad PCA matrix is learnt\n",
320
- int(info));
398
+ fprintf(stderr,
399
+ "WARN ssyev info returns %d, "
400
+ "a very bad PCA matrix is learnt\n",
401
+ int(info));
321
402
  // do not throw exception, as the matrix could still be useful
322
403
  }
323
404
 
324
-
325
- if(verbose && d_in <= 10) {
405
+ if (verbose && d_in <= 10) {
326
406
  printf("info=%ld new eigvals=[", long(info));
327
- for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]);
407
+ for (int j = 0; j < d_in; j++)
408
+ printf("%g ", eigenvalues[j]);
328
409
  printf("]\n");
329
410
 
330
- double *ci = cov;
411
+ double* ci = cov;
331
412
  printf("eigenvecs=\n");
332
- for(int i = 0; i < d_in; i++) {
333
- for(int j = 0; j < d_in; j++)
413
+ for (int i = 0; i < d_in; i++) {
414
+ for (int j = 0; j < d_in; j++)
334
415
  printf("%10.4g ", *ci++);
335
416
  printf("\n");
336
417
  }
337
418
  }
338
-
339
419
  }
340
420
 
341
421
  // revert order of eigenvectors & values
342
422
 
343
- for(int i = 0; i < d_in / 2; i++) {
344
-
423
+ for (int i = 0; i < d_in / 2; i++) {
345
424
  std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
346
- double *v1 = cov + i * d_in;
347
- double *v2 = cov + (d_in - 1 - i) * d_in;
348
- for(int j = 0; j < d_in; j++)
425
+ double* v1 = cov + i * d_in;
426
+ double* v2 = cov + (d_in - 1 - i) * d_in;
427
+ for (int j = 0; j < d_in; j++)
349
428
  std::swap(v1[j], v2[j]);
350
429
  }
351
-
352
430
  }
353
431
 
432
+ } // namespace
354
433
 
355
- }
356
-
357
- void PCAMatrix::train (Index::idx_t n, const float *x)
358
- {
359
- const float * x_in = x;
434
+ void PCAMatrix::train(Index::idx_t n, const float* x) {
435
+ const float* x_in = x;
360
436
 
361
- x = fvecs_maybe_subsample (d_in, (size_t*)&n,
362
- max_points_per_d * d_in, x, verbose);
437
+ x = fvecs_maybe_subsample(
438
+ d_in, (size_t*)&n, max_points_per_d * d_in, x, verbose);
363
439
 
364
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
440
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
365
441
 
366
442
  // compute mean
367
- mean.clear(); mean.resize(d_in, 0.0);
443
+ mean.clear();
444
+ mean.resize(d_in, 0.0);
368
445
  if (have_bias) { // we may want to skip the bias
369
- const float *xi = x;
446
+ const float* xi = x;
370
447
  for (int i = 0; i < n; i++) {
371
- for(int j = 0; j < d_in; j++)
448
+ for (int j = 0; j < d_in; j++)
372
449
  mean[j] += *xi++;
373
450
  }
374
- for(int j = 0; j < d_in; j++)
451
+ for (int j = 0; j < d_in; j++)
375
452
  mean[j] /= n;
376
453
  }
377
- if(verbose) {
454
+ if (verbose) {
378
455
  printf("mean=[");
379
- for(int j = 0; j < d_in; j++) printf("%g ", mean[j]);
456
+ for (int j = 0; j < d_in; j++)
457
+ printf("%g ", mean[j]);
380
458
  printf("]\n");
381
459
  }
382
460
 
383
- if(n >= d_in) {
461
+ if (n >= d_in) {
384
462
  // compute covariance matrix, store it in PCA matrix
385
463
  PCAMat.resize(d_in * d_in);
386
- float * cov = PCAMat.data();
464
+ float* cov = PCAMat.data();
387
465
  { // initialize with mean * mean^T term
388
- float *ci = cov;
389
- for(int i = 0; i < d_in; i++) {
390
- for(int j = 0; j < d_in; j++)
391
- *ci++ = - n * mean[i] * mean[j];
466
+ float* ci = cov;
467
+ for (int i = 0; i < d_in; i++) {
468
+ for (int j = 0; j < d_in; j++)
469
+ *ci++ = -n * mean[i] * mean[j];
392
470
  }
393
471
  }
394
472
  {
395
473
  FINTEGER di = d_in, ni = n;
396
474
  float one = 1.0;
397
- ssyrk_ ("Up", "Non transposed",
398
- &di, &ni, &one, (float*)x, &di, &one, cov, &di);
399
-
475
+ ssyrk_("Up",
476
+ "Non transposed",
477
+ &di,
478
+ &ni,
479
+ &one,
480
+ (float*)x,
481
+ &di,
482
+ &one,
483
+ cov,
484
+ &di);
400
485
  }
401
- if(verbose && d_in <= 10) {
402
- float *ci = cov;
486
+ if (verbose && d_in <= 10) {
487
+ float* ci = cov;
403
488
  printf("cov=\n");
404
- for(int i = 0; i < d_in; i++) {
405
- for(int j = 0; j < d_in; j++)
489
+ for (int i = 0; i < d_in; i++) {
490
+ for (int j = 0; j < d_in; j++)
406
491
  printf("%10g ", *ci++);
407
492
  printf("\n");
408
493
  }
409
494
  }
410
495
 
411
- std::vector<double> covd (d_in * d_in);
412
- for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i];
496
+ std::vector<double> covd(d_in * d_in);
497
+ for (size_t i = 0; i < d_in * d_in; i++)
498
+ covd[i] = cov[i];
413
499
 
414
- std::vector<double> eigenvaluesd (d_in);
500
+ std::vector<double> eigenvaluesd(d_in);
415
501
 
416
- eig (d_in, covd.data (), eigenvaluesd.data (), verbose);
502
+ eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
417
503
 
418
- for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i];
419
- eigenvalues.resize (d_in);
504
+ for (size_t i = 0; i < d_in * d_in; i++)
505
+ PCAMat[i] = covd[i];
506
+ eigenvalues.resize(d_in);
420
507
 
421
508
  for (size_t i = 0; i < d_in; i++)
422
- eigenvalues [i] = eigenvaluesd [i];
423
-
509
+ eigenvalues[i] = eigenvaluesd[i];
424
510
 
425
511
  } else {
426
-
427
- std::vector<float> xc (n * d_in);
512
+ std::vector<float> xc(n * d_in);
428
513
 
429
514
  for (size_t i = 0; i < n; i++)
430
- for(size_t j = 0; j < d_in; j++)
431
- xc [i * d_in + j] = x [i * d_in + j] - mean[j];
515
+ for (size_t j = 0; j < d_in; j++)
516
+ xc[i * d_in + j] = x[i * d_in + j] - mean[j];
432
517
 
433
518
  // compute Gram matrix
434
- std::vector<float> gram (n * n);
519
+ std::vector<float> gram(n * n);
435
520
  {
436
521
  FINTEGER di = d_in, ni = n;
437
522
  float one = 1.0, zero = 0.0;
438
- ssyrk_ ("Up", "Transposed",
439
- &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni);
523
+ ssyrk_("Up",
524
+ "Transposed",
525
+ &ni,
526
+ &di,
527
+ &one,
528
+ xc.data(),
529
+ &di,
530
+ &zero,
531
+ gram.data(),
532
+ &ni);
440
533
  }
441
534
 
442
- if(verbose && d_in <= 10) {
443
- float *ci = gram.data();
535
+ if (verbose && d_in <= 10) {
536
+ float* ci = gram.data();
444
537
  printf("gram=\n");
445
- for(int i = 0; i < n; i++) {
446
- for(int j = 0; j < n; j++)
538
+ for (int i = 0; i < n; i++) {
539
+ for (int j = 0; j < n; j++)
447
540
  printf("%10g ", *ci++);
448
541
  printf("\n");
449
542
  }
450
543
  }
451
544
 
452
- std::vector<double> gramd (n * n);
545
+ std::vector<double> gramd(n * n);
453
546
  for (size_t i = 0; i < n * n; i++)
454
- gramd [i] = gram [i];
547
+ gramd[i] = gram[i];
455
548
 
456
- std::vector<double> eigenvaluesd (n);
549
+ std::vector<double> eigenvaluesd(n);
457
550
 
458
551
  // eig will fill in only the n first eigenvals
459
552
 
460
- eig (n, gramd.data (), eigenvaluesd.data (), verbose);
553
+ eig(n, gramd.data(), eigenvaluesd.data(), verbose);
461
554
 
462
555
  PCAMat.resize(d_in * n);
463
556
 
464
557
  for (size_t i = 0; i < n * n; i++)
465
- gram [i] = gramd [i];
558
+ gram[i] = gramd[i];
466
559
 
467
- eigenvalues.resize (d_in);
560
+ eigenvalues.resize(d_in);
468
561
  // fill in only the n first ones
469
562
  for (size_t i = 0; i < n; i++)
470
- eigenvalues [i] = eigenvaluesd [i];
563
+ eigenvalues[i] = eigenvaluesd[i];
471
564
 
472
565
  { // compute PCAMat = x' * v
473
566
  FINTEGER di = d_in, ni = n;
474
567
  float one = 1.0;
475
568
 
476
- sgemm_ ("Non", "Non Trans",
477
- &di, &ni, &ni,
478
- &one, xc.data(), &di, gram.data(), &ni,
479
- &one, PCAMat.data(), &di);
569
+ sgemm_("Non",
570
+ "Non Trans",
571
+ &di,
572
+ &ni,
573
+ &ni,
574
+ &one,
575
+ xc.data(),
576
+ &di,
577
+ gram.data(),
578
+ &ni,
579
+ &one,
580
+ PCAMat.data(),
581
+ &di);
480
582
  }
481
583
 
482
- if(verbose && d_in <= 10) {
483
- float *ci = PCAMat.data();
584
+ if (verbose && d_in <= 10) {
585
+ float* ci = PCAMat.data();
484
586
  printf("PCAMat=\n");
485
- for(int i = 0; i < n; i++) {
486
- for(int j = 0; j < d_in; j++)
587
+ for (int i = 0; i < n; i++) {
588
+ for (int j = 0; j < d_in; j++)
487
589
  printf("%10g ", *ci++);
488
590
  printf("\n");
489
591
  }
490
592
  }
491
- fvec_renorm_L2 (d_in, n, PCAMat.data());
492
-
593
+ fvec_renorm_L2(d_in, n, PCAMat.data());
493
594
  }
494
595
 
495
596
  prepare_Ab();
496
597
  is_trained = true;
497
598
  }
498
599
 
499
- void PCAMatrix::copy_from (const PCAMatrix & other)
500
- {
501
- FAISS_THROW_IF_NOT (other.is_trained);
600
+ void PCAMatrix::copy_from(const PCAMatrix& other) {
601
+ FAISS_THROW_IF_NOT(other.is_trained);
502
602
  mean = other.mean;
503
603
  eigenvalues = other.eigenvalues;
504
604
  PCAMat = other.PCAMat;
505
- prepare_Ab ();
605
+ prepare_Ab();
506
606
  is_trained = true;
507
607
  }
508
608
 
509
- void PCAMatrix::prepare_Ab ()
510
- {
511
- FAISS_THROW_IF_NOT_FMT (
609
+ void PCAMatrix::prepare_Ab() {
610
+ FAISS_THROW_IF_NOT_FMT(
512
611
  d_out * d_in <= PCAMat.size(),
513
612
  "PCA matrix cannot output %d dimensions from %d ",
514
- d_out, d_in);
613
+ d_out,
614
+ d_in);
515
615
 
516
616
  if (!random_rotation) {
517
617
  A = PCAMat;
@@ -519,23 +619,23 @@ void PCAMatrix::prepare_Ab ()
519
619
 
520
620
  // first scale the components
521
621
  if (eigen_power != 0) {
522
- float *ai = A.data();
622
+ float* ai = A.data();
523
623
  for (int i = 0; i < d_out; i++) {
524
- float factor = pow(eigenvalues[i], eigen_power);
525
- for(int j = 0; j < d_in; j++)
624
+ float factor = pow(eigenvalues[i] + epsilon, eigen_power);
625
+ for (int j = 0; j < d_in; j++)
526
626
  *ai++ *= factor;
527
627
  }
528
628
  }
529
629
 
530
630
  if (balanced_bins != 0) {
531
- FAISS_THROW_IF_NOT (d_out % balanced_bins == 0);
631
+ FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
532
632
  int dsub = d_out / balanced_bins;
533
- std::vector <float> Ain;
633
+ std::vector<float> Ain;
534
634
  std::swap(A, Ain);
535
635
  A.resize(d_out * d_in);
536
636
 
537
- std::vector <float> accu(balanced_bins);
538
- std::vector <int> counter(balanced_bins);
637
+ std::vector<float> accu(balanced_bins);
638
+ std::vector<int> counter(balanced_bins);
539
639
 
540
640
  // greedy assignment
541
641
  for (int i = 0; i < d_out; i++) {
@@ -550,9 +650,8 @@ void PCAMatrix::prepare_Ab ()
550
650
  }
551
651
  int row_dst = best_j * dsub + counter[best_j];
552
652
  accu[best_j] += eigenvalues[i];
553
- counter[best_j] ++;
554
- memcpy (&A[row_dst * d_in], &Ain[i * d_in],
555
- d_in * sizeof (A[0]));
653
+ counter[best_j]++;
654
+ memcpy(&A[row_dst * d_in], &Ain[i * d_in], d_in * sizeof(A[0]));
556
655
  }
557
656
 
558
657
  if (verbose) {
@@ -563,11 +662,11 @@ void PCAMatrix::prepare_Ab ()
563
662
  }
564
663
  }
565
664
 
566
-
567
665
  } else {
568
- FAISS_THROW_IF_NOT_MSG (balanced_bins == 0,
569
- "both balancing bins and applying a random rotation "
570
- "does not make sense");
666
+ FAISS_THROW_IF_NOT_MSG(
667
+ balanced_bins == 0,
668
+ "both balancing bins and applying a random rotation "
669
+ "does not make sense");
571
670
  RandomRotationMatrix rr(d_out, d_out);
572
671
 
573
672
  rr.init(5);
@@ -576,8 +675,8 @@ void PCAMatrix::prepare_Ab ()
576
675
  if (eigen_power != 0) {
577
676
  for (int i = 0; i < d_out; i++) {
578
677
  float factor = pow(eigenvalues[i], eigen_power);
579
- for(int j = 0; j < d_out; j++)
580
- rr.A[j * d_out + i] *= factor;
678
+ for (int j = 0; j < d_out; j++)
679
+ rr.A[j * d_out + i] *= factor;
581
680
  }
582
681
  }
583
682
 
@@ -586,15 +685,24 @@ void PCAMatrix::prepare_Ab ()
586
685
  FINTEGER dii = d_in, doo = d_out;
587
686
  float one = 1.0, zero = 0.0;
588
687
 
589
- sgemm_ ("Not", "Not", &dii, &doo, &doo,
590
- &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero,
591
- A.data(), &dii);
592
-
688
+ sgemm_("Not",
689
+ "Not",
690
+ &dii,
691
+ &doo,
692
+ &doo,
693
+ &one,
694
+ PCAMat.data(),
695
+ &dii,
696
+ rr.A.data(),
697
+ &doo,
698
+ &zero,
699
+ A.data(),
700
+ &dii);
593
701
  }
594
-
595
702
  }
596
703
 
597
- b.clear(); b.resize(d_out);
704
+ b.clear();
705
+ b.resize(d_out);
598
706
 
599
707
  for (int i = 0; i < d_out; i++) {
600
708
  float accu = 0;
@@ -604,57 +712,61 @@ void PCAMatrix::prepare_Ab ()
604
712
  }
605
713
 
606
714
  is_orthonormal = eigen_power == 0;
607
-
608
715
  }
609
716
 
610
717
  /*********************************************
611
718
  * ITQMatrix
612
719
  *********************************************/
613
720
 
614
- ITQMatrix::ITQMatrix (int d):
615
- LinearTransform(d, d, false),
616
- max_iter (50),
617
- seed (123)
618
- {
619
- }
620
-
721
+ ITQMatrix::ITQMatrix(int d)
722
+ : LinearTransform(d, d, false), max_iter(50), seed(123) {}
621
723
 
622
724
  /** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
623
- void ITQMatrix::train (Index::idx_t n, const float* xf)
624
- {
725
+ void ITQMatrix::train(Index::idx_t n, const float* xf) {
625
726
  size_t d = d_in;
626
- std::vector<double> rotation (d * d);
727
+ std::vector<double> rotation(d * d);
627
728
 
628
729
  if (init_rotation.size() == d * d) {
629
- memcpy (rotation.data(), init_rotation.data(),
630
- d * d * sizeof(rotation[0]));
730
+ memcpy(rotation.data(),
731
+ init_rotation.data(),
732
+ d * d * sizeof(rotation[0]));
631
733
  } else {
632
- RandomRotationMatrix rrot (d, d);
633
- rrot.init (seed);
734
+ RandomRotationMatrix rrot(d, d);
735
+ rrot.init(seed);
634
736
  for (size_t i = 0; i < d * d; i++) {
635
737
  rotation[i] = rrot.A[i];
636
738
  }
637
739
  }
638
740
 
639
- std::vector<double> x (n * d);
741
+ std::vector<double> x(n * d);
640
742
 
641
743
  for (size_t i = 0; i < n * d; i++) {
642
744
  x[i] = xf[i];
643
745
  }
644
746
 
645
- std::vector<double> rotated_x (n * d), cov_mat (d * d);
646
- std::vector<double> u (d * d), vt (d * d), singvals (d);
747
+ std::vector<double> rotated_x(n * d), cov_mat(d * d);
748
+ std::vector<double> u(d * d), vt(d * d), singvals(d);
647
749
 
648
750
  for (int i = 0; i < max_iter; i++) {
649
- print_if_verbose ("rotation", rotation, d, d);
751
+ print_if_verbose("rotation", rotation, d, d);
650
752
  { // rotated_data = np.dot(training_data, rotation)
651
753
  FINTEGER di = d, ni = n;
652
754
  double one = 1, zero = 0;
653
- dgemm_ ("N", "N", &di, &ni, &di,
654
- &one, rotation.data(), &di, x.data(), &di,
655
- &zero, rotated_x.data(), &di);
755
+ dgemm_("N",
756
+ "N",
757
+ &di,
758
+ &ni,
759
+ &di,
760
+ &one,
761
+ rotation.data(),
762
+ &di,
763
+ x.data(),
764
+ &di,
765
+ &zero,
766
+ rotated_x.data(),
767
+ &di);
656
768
  }
657
- print_if_verbose ("rotated_x", rotated_x, n, d);
769
+ print_if_verbose("rotated_x", rotated_x, n, d);
658
770
  // binarize
659
771
  for (size_t j = 0; j < n * d; j++) {
660
772
  rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
@@ -663,88 +775,119 @@ void ITQMatrix::train (Index::idx_t n, const float* xf)
663
775
  { // rotated_data = np.dot(training_data, rotation)
664
776
  FINTEGER di = d, ni = n;
665
777
  double one = 1, zero = 0;
666
- dgemm_ ("N", "T", &di, &di, &ni,
667
- &one, rotated_x.data(), &di, x.data(), &di,
668
- &zero, cov_mat.data(), &di);
778
+ dgemm_("N",
779
+ "T",
780
+ &di,
781
+ &di,
782
+ &ni,
783
+ &one,
784
+ rotated_x.data(),
785
+ &di,
786
+ x.data(),
787
+ &di,
788
+ &zero,
789
+ cov_mat.data(),
790
+ &di);
669
791
  }
670
- print_if_verbose ("cov_mat", cov_mat, d, d);
792
+ print_if_verbose("cov_mat", cov_mat, d, d);
671
793
  // SVD
672
794
  {
673
-
674
795
  FINTEGER di = d;
675
796
  FINTEGER lwork = -1, info;
676
797
  double lwork1;
677
798
 
678
799
  // workspace query
679
- dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
680
- singvals.data(), u.data(), &di,
681
- vt.data(), &di,
682
- &lwork1, &lwork, &info);
683
-
684
- FAISS_THROW_IF_NOT (info == 0);
685
- lwork = size_t (lwork1);
686
- std::vector<double> work (lwork);
687
- dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
688
- singvals.data(), u.data(), &di,
689
- vt.data(), &di,
690
- work.data(), &lwork, &info);
691
- FAISS_THROW_IF_NOT_FMT (info == 0, "sgesvd returned info=%d", info);
692
-
800
+ dgesvd_("A",
801
+ "A",
802
+ &di,
803
+ &di,
804
+ cov_mat.data(),
805
+ &di,
806
+ singvals.data(),
807
+ u.data(),
808
+ &di,
809
+ vt.data(),
810
+ &di,
811
+ &lwork1,
812
+ &lwork,
813
+ &info);
814
+
815
+ FAISS_THROW_IF_NOT(info == 0);
816
+ lwork = size_t(lwork1);
817
+ std::vector<double> work(lwork);
818
+ dgesvd_("A",
819
+ "A",
820
+ &di,
821
+ &di,
822
+ cov_mat.data(),
823
+ &di,
824
+ singvals.data(),
825
+ u.data(),
826
+ &di,
827
+ vt.data(),
828
+ &di,
829
+ work.data(),
830
+ &lwork,
831
+ &info);
832
+ FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
693
833
  }
694
- print_if_verbose ("u", u, d, d);
695
- print_if_verbose ("vt", vt, d, d);
834
+ print_if_verbose("u", u, d, d);
835
+ print_if_verbose("vt", vt, d, d);
696
836
  // update rotation
697
837
  {
698
838
  FINTEGER di = d;
699
839
  double one = 1, zero = 0;
700
- dgemm_ ("N", "T", &di, &di, &di,
701
- &one, u.data(), &di, vt.data(), &di,
702
- &zero, rotation.data(), &di);
840
+ dgemm_("N",
841
+ "T",
842
+ &di,
843
+ &di,
844
+ &di,
845
+ &one,
846
+ u.data(),
847
+ &di,
848
+ vt.data(),
849
+ &di,
850
+ &zero,
851
+ rotation.data(),
852
+ &di);
703
853
  }
704
- print_if_verbose ("final rot", rotation, d, d);
705
-
854
+ print_if_verbose("final rot", rotation, d, d);
706
855
  }
707
- A.resize (d * d);
856
+ A.resize(d * d);
708
857
  for (size_t i = 0; i < d; i++) {
709
858
  for (size_t j = 0; j < d; j++) {
710
859
  A[i + d * j] = rotation[j + d * i];
711
860
  }
712
861
  }
713
862
  is_trained = true;
714
-
715
863
  }
716
864
 
717
- ITQTransform::ITQTransform (int d_in, int d_out, bool do_pca):
718
- VectorTransform (d_in, d_out),
719
- do_pca (do_pca),
720
- itq (d_out),
721
- pca_then_itq (d_in, d_out, false)
722
- {
865
+ ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca)
866
+ : VectorTransform(d_in, d_out),
867
+ do_pca(do_pca),
868
+ itq(d_out),
869
+ pca_then_itq(d_in, d_out, false) {
723
870
  if (!do_pca) {
724
- FAISS_THROW_IF_NOT (d_in == d_out);
871
+ FAISS_THROW_IF_NOT(d_in == d_out);
725
872
  }
726
873
  max_train_per_dim = 10;
727
874
  is_trained = false;
728
875
  }
729
876
 
877
+ void ITQTransform::train(idx_t n, const float* x) {
878
+ FAISS_THROW_IF_NOT(!is_trained);
730
879
 
731
-
732
-
733
- void ITQTransform::train (idx_t n, const float *x)
734
- {
735
- FAISS_THROW_IF_NOT (!is_trained);
736
-
737
- const float * x_in = x;
880
+ const float* x_in = x;
738
881
  size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
739
- x = fvecs_maybe_subsample (d_in, (size_t*)&n, max_train_points, x);
882
+ x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x);
740
883
 
741
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
884
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
742
885
 
743
- std::unique_ptr<float []> x_norm(new float[n * d_in]);
886
+ std::unique_ptr<float[]> x_norm(new float[n * d_in]);
744
887
  { // normalize
745
888
  int d = d_in;
746
889
 
747
- mean.resize (d, 0);
890
+ mean.resize(d, 0);
748
891
  for (idx_t i = 0; i < n; i++) {
749
892
  for (idx_t j = 0; j < d; j++) {
750
893
  mean[j] += x[i * d + j];
@@ -755,38 +898,47 @@ void ITQTransform::train (idx_t n, const float *x)
755
898
  }
756
899
  for (idx_t i = 0; i < n; i++) {
757
900
  for (idx_t j = 0; j < d; j++) {
758
- x_norm[i * d + j] = x[i * d + j] - mean[j];
901
+ x_norm[i * d + j] = x[i * d + j] - mean[j];
759
902
  }
760
903
  }
761
- fvec_renorm_L2 (d_in, n, x_norm.get());
904
+ fvec_renorm_L2(d_in, n, x_norm.get());
762
905
  }
763
906
 
764
907
  // train PCA
765
908
 
766
- PCAMatrix pca (d_in, d_out);
767
- float *x_pca;
768
- std::unique_ptr<float []> x_pca_del;
909
+ PCAMatrix pca(d_in, d_out);
910
+ float* x_pca;
911
+ std::unique_ptr<float[]> x_pca_del;
769
912
  if (do_pca) {
770
- pca.have_bias = false; // for consistency with reference implem
771
- pca.train (n, x_norm.get());
772
- x_pca = pca.apply (n, x_norm.get());
913
+ pca.have_bias = false; // for consistency with reference implem
914
+ pca.train(n, x_norm.get());
915
+ x_pca = pca.apply(n, x_norm.get());
773
916
  x_pca_del.reset(x_pca);
774
917
  } else {
775
918
  x_pca = x_norm.get();
776
919
  }
777
920
 
778
921
  // train ITQ
779
- itq.train (n, x_pca);
922
+ itq.train(n, x_pca);
780
923
 
781
924
  // merge PCA and ITQ
782
925
  if (do_pca) {
783
926
  FINTEGER di = d_out, dini = d_in;
784
927
  float one = 1, zero = 0;
785
928
  pca_then_itq.A.resize(d_in * d_out);
786
- sgemm_ ("N", "N", &dini, &di, &di,
787
- &one, pca.A.data(), &dini,
788
- itq.A.data(), &di,
789
- &zero, pca_then_itq.A.data(), &dini);
929
+ sgemm_("N",
930
+ "N",
931
+ &dini,
932
+ &di,
933
+ &di,
934
+ &one,
935
+ pca.A.data(),
936
+ &dini,
937
+ itq.A.data(),
938
+ &di,
939
+ &zero,
940
+ pca_then_itq.A.data(),
941
+ &dini);
790
942
  } else {
791
943
  pca_then_itq.A = itq.A;
792
944
  }
@@ -794,12 +946,11 @@ void ITQTransform::train (idx_t n, const float *x)
794
946
  is_trained = true;
795
947
  }
796
948
 
797
- void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
798
- float * xt) const
799
- {
949
+ void ITQTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
950
+ const {
800
951
  FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
801
952
 
802
- std::unique_ptr<float []> x_norm(new float[n * d_in]);
953
+ std::unique_ptr<float[]> x_norm(new float[n * d_in]);
803
954
  { // normalize
804
955
  int d = d_in;
805
956
  for (idx_t i = 0; i < n; i++) {
@@ -809,41 +960,36 @@ void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
809
960
  }
810
961
  // this is not really useful if we are going to binarize right
811
962
  // afterwards but OK
812
- fvec_renorm_L2 (d_in, n, x_norm.get());
963
+ fvec_renorm_L2(d_in, n, x_norm.get());
813
964
  }
814
965
 
815
- pca_then_itq.apply_noalloc (n, x_norm.get(), xt);
966
+ pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
816
967
  }
817
968
 
818
969
  /*********************************************
819
970
  * OPQMatrix
820
971
  *********************************************/
821
972
 
822
-
823
- OPQMatrix::OPQMatrix (int d, int M, int d2):
824
- LinearTransform (d, d2 == -1 ? d : d2, false), M(M),
825
- niter (50),
826
- niter_pq (4), niter_pq_0 (40),
827
- verbose(false),
828
- pq(nullptr)
829
- {
973
+ OPQMatrix::OPQMatrix(int d, int M, int d2)
974
+ : LinearTransform(d, d2 == -1 ? d : d2, false),
975
+ M(M),
976
+ niter(50),
977
+ niter_pq(4),
978
+ niter_pq_0(40),
979
+ verbose(false),
980
+ pq(nullptr) {
830
981
  is_trained = false;
831
982
  // OPQ is quite expensive to train, so set this right.
832
983
  max_train_points = 256 * 256;
833
984
  pq = nullptr;
834
985
  }
835
986
 
987
+ void OPQMatrix::train(Index::idx_t n, const float* x) {
988
+ const float* x_in = x;
836
989
 
990
+ x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose);
837
991
 
838
- void OPQMatrix::train (Index::idx_t n, const float *x)
839
- {
840
-
841
- const float * x_in = x;
842
-
843
- x = fvecs_maybe_subsample (d_in, (size_t*)&n,
844
- max_train_points, x, verbose);
845
-
846
- ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
992
+ ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
847
993
 
848
994
  // To support d_out > d_in, we pad input vectors with 0s to d_out
849
995
  size_t d = d_out <= d_in ? d_in : d_out;
@@ -867,22 +1013,26 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
867
1013
  #endif
868
1014
 
869
1015
  if (verbose) {
870
- printf ("OPQMatrix::train: training an OPQ rotation matrix "
871
- "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
872
- M, n, d_in, d_out);
1016
+ printf("OPQMatrix::train: training an OPQ rotation matrix "
1017
+ "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
1018
+ M,
1019
+ n,
1020
+ d_in,
1021
+ d_out);
873
1022
  }
874
1023
 
875
- std::vector<float> xtrain (n * d);
1024
+ std::vector<float> xtrain(n * d);
876
1025
  // center x
877
1026
  {
878
- std::vector<float> sum (d);
879
- const float *xi = x;
1027
+ std::vector<float> sum(d);
1028
+ const float* xi = x;
880
1029
  for (size_t i = 0; i < n; i++) {
881
1030
  for (int j = 0; j < d_in; j++)
882
- sum [j] += *xi++;
1031
+ sum[j] += *xi++;
883
1032
  }
884
- for (int i = 0; i < d; i++) sum[i] /= n;
885
- float *yi = xtrain.data();
1033
+ for (int i = 0; i < d; i++)
1034
+ sum[i] /= n;
1035
+ float* yi = xtrain.data();
886
1036
  xi = x;
887
1037
  for (size_t i = 0; i < n; i++) {
888
1038
  for (int j = 0; j < d_in; j++)
@@ -890,71 +1040,80 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
890
1040
  yi += d - d_in;
891
1041
  }
892
1042
  }
893
- float *rotation;
1043
+ float* rotation;
894
1044
 
895
- if (A.size () == 0) {
896
- A.resize (d * d);
1045
+ if (A.size() == 0) {
1046
+ A.resize(d * d);
897
1047
  rotation = A.data();
898
1048
  if (verbose)
899
1049
  printf(" OPQMatrix::train: making random %zd*%zd rotation\n",
900
- d, d);
901
- float_randn (rotation, d * d, 1234);
902
- matrix_qr (d, d, rotation);
1050
+ d,
1051
+ d);
1052
+ float_randn(rotation, d * d, 1234);
1053
+ matrix_qr(d, d, rotation);
903
1054
  // we use only the d * d2 upper part of the matrix
904
- A.resize (d * d2);
1055
+ A.resize(d * d2);
905
1056
  } else {
906
- FAISS_THROW_IF_NOT (A.size() == d * d2);
1057
+ FAISS_THROW_IF_NOT(A.size() == d * d2);
907
1058
  rotation = A.data();
908
1059
  }
909
1060
 
910
- std::vector<float>
911
- xproj (d2 * n), pq_recons (d2 * n), xxr (d * n),
912
- tmp(d * d * 4);
913
-
1061
+ std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * n),
1062
+ tmp(d * d * 4);
914
1063
 
915
- ProductQuantizer pq_default (d2, M, 8);
916
- ProductQuantizer &pq_regular = pq ? *pq : pq_default;
917
- std::vector<uint8_t> codes (pq_regular.code_size * n);
1064
+ ProductQuantizer pq_default(d2, M, 8);
1065
+ ProductQuantizer& pq_regular = pq ? *pq : pq_default;
1066
+ std::vector<uint8_t> codes(pq_regular.code_size * n);
918
1067
 
919
1068
  double t0 = getmillisecs();
920
1069
  for (int iter = 0; iter < niter; iter++) {
921
-
922
1070
  { // torch.mm(xtrain, rotation:t())
923
1071
  FINTEGER di = d, d2i = d2, ni = n;
924
1072
  float zero = 0, one = 1;
925
- sgemm_ ("Transposed", "Not transposed",
926
- &d2i, &ni, &di,
927
- &one, rotation, &di,
928
- xtrain.data(), &di,
929
- &zero, xproj.data(), &d2i);
1073
+ sgemm_("Transposed",
1074
+ "Not transposed",
1075
+ &d2i,
1076
+ &ni,
1077
+ &di,
1078
+ &one,
1079
+ rotation,
1080
+ &di,
1081
+ xtrain.data(),
1082
+ &di,
1083
+ &zero,
1084
+ xproj.data(),
1085
+ &d2i);
930
1086
  }
931
1087
 
932
1088
  pq_regular.cp.max_points_per_centroid = 1000;
933
1089
  pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
934
1090
  pq_regular.verbose = verbose;
935
- pq_regular.train (n, xproj.data());
1091
+ pq_regular.train(n, xproj.data());
936
1092
 
937
1093
  if (verbose) {
938
1094
  printf(" encode / decode\n");
939
1095
  }
940
1096
  if (pq_regular.assign_index) {
941
- pq_regular.compute_codes_with_assign_index
942
- (xproj.data(), codes.data(), n);
1097
+ pq_regular.compute_codes_with_assign_index(
1098
+ xproj.data(), codes.data(), n);
943
1099
  } else {
944
- pq_regular.compute_codes (xproj.data(), codes.data(), n);
1100
+ pq_regular.compute_codes(xproj.data(), codes.data(), n);
945
1101
  }
946
- pq_regular.decode (codes.data(), pq_recons.data(), n);
1102
+ pq_regular.decode(codes.data(), pq_recons.data(), n);
947
1103
 
948
- float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n;
1104
+ float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;
949
1105
 
950
1106
  if (verbose)
951
- printf (" Iteration %d (%d PQ iterations):"
952
- "%.3f s, obj=%g\n", iter, pq_regular.cp.niter,
953
- (getmillisecs () - t0) / 1000.0, pq_err);
1107
+ printf(" Iteration %d (%d PQ iterations):"
1108
+ "%.3f s, obj=%g\n",
1109
+ iter,
1110
+ pq_regular.cp.niter,
1111
+ (getmillisecs() - t0) / 1000.0,
1112
+ pq_err);
954
1113
 
955
1114
  {
956
- float *u = tmp.data(), *vt = &tmp [d * d];
957
- float *sing_val = &tmp [2 * d * d];
1115
+ float *u = tmp.data(), *vt = &tmp[d * d];
1116
+ float* sing_val = &tmp[2 * d * d];
958
1117
  FINTEGER di = d, d2i = d2, ni = n;
959
1118
  float one = 1, zero = 0;
960
1119
 
@@ -962,36 +1121,69 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
962
1121
  printf(" X * recons\n");
963
1122
  }
964
1123
  // torch.mm(xtrain:t(), pq_recons)
965
- sgemm_ ("Not", "Transposed",
966
- &d2i, &di, &ni,
967
- &one, pq_recons.data(), &d2i,
968
- xtrain.data(), &di,
969
- &zero, xxr.data(), &d2i);
970
-
1124
+ sgemm_("Not",
1125
+ "Transposed",
1126
+ &d2i,
1127
+ &di,
1128
+ &ni,
1129
+ &one,
1130
+ pq_recons.data(),
1131
+ &d2i,
1132
+ xtrain.data(),
1133
+ &di,
1134
+ &zero,
1135
+ xxr.data(),
1136
+ &d2i);
971
1137
 
972
1138
  FINTEGER lwork = -1, info = -1;
973
1139
  float worksz;
974
1140
  // workspace query
975
- sgesvd_ ("All", "All",
976
- &d2i, &di, xxr.data(), &d2i,
977
- sing_val,
978
- vt, &d2i, u, &di,
979
- &worksz, &lwork, &info);
1141
+ sgesvd_("All",
1142
+ "All",
1143
+ &d2i,
1144
+ &di,
1145
+ xxr.data(),
1146
+ &d2i,
1147
+ sing_val,
1148
+ vt,
1149
+ &d2i,
1150
+ u,
1151
+ &di,
1152
+ &worksz,
1153
+ &lwork,
1154
+ &info);
980
1155
 
981
1156
  lwork = int(worksz);
982
- std::vector<float> work (lwork);
1157
+ std::vector<float> work(lwork);
983
1158
  // u and vt swapped
984
- sgesvd_ ("All", "All",
985
- &d2i, &di, xxr.data(), &d2i,
986
- sing_val,
987
- vt, &d2i, u, &di,
988
- work.data(), &lwork, &info);
989
-
990
- sgemm_ ("Transposed", "Transposed",
991
- &di, &d2i, &d2i,
992
- &one, u, &di, vt, &d2i,
993
- &zero, rotation, &di);
994
-
1159
+ sgesvd_("All",
1160
+ "All",
1161
+ &d2i,
1162
+ &di,
1163
+ xxr.data(),
1164
+ &d2i,
1165
+ sing_val,
1166
+ vt,
1167
+ &d2i,
1168
+ u,
1169
+ &di,
1170
+ work.data(),
1171
+ &lwork,
1172
+ &info);
1173
+
1174
+ sgemm_("Transposed",
1175
+ "Transposed",
1176
+ &di,
1177
+ &d2i,
1178
+ &d2i,
1179
+ &one,
1180
+ u,
1181
+ &di,
1182
+ vt,
1183
+ &d2i,
1184
+ &zero,
1185
+ rotation,
1186
+ &di);
995
1187
  }
996
1188
  pq_regular.train_type = ProductQuantizer::Train_hot_start;
997
1189
  }
@@ -999,59 +1191,52 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
999
1191
  // revert A matrix
1000
1192
  if (d > d_in) {
1001
1193
  for (long i = 0; i < d_out; i++)
1002
- memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
1003
- A.resize (d_in * d_out);
1194
+ memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
1195
+ A.resize(d_in * d_out);
1004
1196
  }
1005
1197
 
1006
1198
  is_trained = true;
1007
1199
  is_orthonormal = true;
1008
1200
  }
1009
1201
 
1010
-
1011
1202
  /*********************************************
1012
1203
  * NormalizationTransform
1013
1204
  *********************************************/
1014
1205
 
1015
- NormalizationTransform::NormalizationTransform (int d, float norm):
1016
- VectorTransform (d, d), norm (norm)
1017
- {
1018
- }
1206
+ NormalizationTransform::NormalizationTransform(int d, float norm)
1207
+ : VectorTransform(d, d), norm(norm) {}
1019
1208
 
1020
- NormalizationTransform::NormalizationTransform ():
1021
- VectorTransform (-1, -1), norm (-1)
1022
- {
1023
- }
1209
+ NormalizationTransform::NormalizationTransform()
1210
+ : VectorTransform(-1, -1), norm(-1) {}
1024
1211
 
1025
- void NormalizationTransform::apply_noalloc
1026
- (idx_t n, const float* x, float* xt) const
1027
- {
1212
+ void NormalizationTransform::apply_noalloc(idx_t n, const float* x, float* xt)
1213
+ const {
1028
1214
  if (norm == 2.0) {
1029
- memcpy (xt, x, sizeof (x[0]) * n * d_in);
1030
- fvec_renorm_L2 (d_in, n, xt);
1215
+ memcpy(xt, x, sizeof(x[0]) * n * d_in);
1216
+ fvec_renorm_L2(d_in, n, xt);
1031
1217
  } else {
1032
- FAISS_THROW_MSG ("not implemented");
1218
+ FAISS_THROW_MSG("not implemented");
1033
1219
  }
1034
1220
  }
1035
1221
 
1036
- void NormalizationTransform::reverse_transform (idx_t n, const float* xt,
1037
- float* x) const
1038
- {
1039
- memcpy (x, xt, sizeof (xt[0]) * n * d_in);
1222
+ void NormalizationTransform::reverse_transform(
1223
+ idx_t n,
1224
+ const float* xt,
1225
+ float* x) const {
1226
+ memcpy(x, xt, sizeof(xt[0]) * n * d_in);
1040
1227
  }
1041
1228
 
1042
1229
  /*********************************************
1043
1230
  * CenteringTransform
1044
1231
  *********************************************/
1045
1232
 
1046
- CenteringTransform::CenteringTransform (int d):
1047
- VectorTransform (d, d)
1048
- {
1233
+ CenteringTransform::CenteringTransform(int d) : VectorTransform(d, d) {
1049
1234
  is_trained = false;
1050
1235
  }
1051
1236
 
1052
- void CenteringTransform::train(Index::idx_t n, const float *x) {
1237
+ void CenteringTransform::train(Index::idx_t n, const float* x) {
1053
1238
  FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
1054
- mean.resize (d_in, 0);
1239
+ mean.resize(d_in, 0);
1055
1240
  for (idx_t i = 0; i < n; i++) {
1056
1241
  for (size_t j = 0; j < d_in; j++) {
1057
1242
  mean[j] += *x++;
@@ -1064,11 +1249,9 @@ void CenteringTransform::train(Index::idx_t n, const float *x) {
1064
1249
  is_trained = true;
1065
1250
  }
1066
1251
 
1067
-
1068
- void CenteringTransform::apply_noalloc
1069
- (idx_t n, const float* x, float* xt) const
1070
- {
1071
- FAISS_THROW_IF_NOT (is_trained);
1252
+ void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
1253
+ const {
1254
+ FAISS_THROW_IF_NOT(is_trained);
1072
1255
 
1073
1256
  for (idx_t i = 0; i < n; i++) {
1074
1257
  for (size_t j = 0; j < d_in; j++) {
@@ -1077,64 +1260,58 @@ void CenteringTransform::apply_noalloc
1077
1260
  }
1078
1261
  }
1079
1262
 
1080
- void CenteringTransform::reverse_transform (idx_t n, const float* xt,
1081
- float* x) const
1082
- {
1083
- FAISS_THROW_IF_NOT (is_trained);
1263
+ void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
1264
+ const {
1265
+ FAISS_THROW_IF_NOT(is_trained);
1084
1266
 
1085
1267
  for (idx_t i = 0; i < n; i++) {
1086
1268
  for (size_t j = 0; j < d_in; j++) {
1087
1269
  *x++ = *xt++ + mean[j];
1088
1270
  }
1089
1271
  }
1090
-
1091
1272
  }
1092
1273
 
1093
-
1094
-
1095
-
1096
-
1097
1274
  /*********************************************
1098
1275
  * RemapDimensionsTransform
1099
1276
  *********************************************/
1100
1277
 
1101
-
1102
- RemapDimensionsTransform::RemapDimensionsTransform (
1103
- int d_in, int d_out, const int *map_in):
1104
- VectorTransform (d_in, d_out)
1105
- {
1106
- map.resize (d_out);
1278
+ RemapDimensionsTransform::RemapDimensionsTransform(
1279
+ int d_in,
1280
+ int d_out,
1281
+ const int* map_in)
1282
+ : VectorTransform(d_in, d_out) {
1283
+ map.resize(d_out);
1107
1284
  for (int i = 0; i < d_out; i++) {
1108
1285
  map[i] = map_in[i];
1109
- FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
1286
+ FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
1110
1287
  }
1111
1288
  }
1112
1289
 
1113
- RemapDimensionsTransform::RemapDimensionsTransform (
1114
- int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out)
1115
- {
1116
- map.resize (d_out, -1);
1290
+ RemapDimensionsTransform::RemapDimensionsTransform(
1291
+ int d_in,
1292
+ int d_out,
1293
+ bool uniform)
1294
+ : VectorTransform(d_in, d_out) {
1295
+ map.resize(d_out, -1);
1117
1296
 
1118
1297
  if (uniform) {
1119
1298
  if (d_in < d_out) {
1120
1299
  for (int i = 0; i < d_in; i++) {
1121
- map [i * d_out / d_in] = i;
1122
- }
1300
+ map[i * d_out / d_in] = i;
1301
+ }
1123
1302
  } else {
1124
1303
  for (int i = 0; i < d_out; i++) {
1125
- map [i] = i * d_in / d_out;
1304
+ map[i] = i * d_in / d_out;
1126
1305
  }
1127
1306
  }
1128
1307
  } else {
1129
1308
  for (int i = 0; i < d_in && i < d_out; i++)
1130
- map [i] = i;
1309
+ map[i] = i;
1131
1310
  }
1132
1311
  }
1133
1312
 
1134
-
1135
- void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
1136
- float *xt) const
1137
- {
1313
+ void RemapDimensionsTransform::apply_noalloc(idx_t n, const float* x, float* xt)
1314
+ const {
1138
1315
  for (idx_t i = 0; i < n; i++) {
1139
1316
  for (int j = 0; j < d_out; j++) {
1140
1317
  xt[j] = map[j] < 0 ? 0 : x[map[j]];
@@ -1144,13 +1321,15 @@ void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
1144
1321
  }
1145
1322
  }
1146
1323
 
1147
- void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt,
1148
- float *x) const
1149
- {
1150
- memset (x, 0, sizeof (*x) * n * d_in);
1324
+ void RemapDimensionsTransform::reverse_transform(
1325
+ idx_t n,
1326
+ const float* xt,
1327
+ float* x) const {
1328
+ memset(x, 0, sizeof(*x) * n * d_in);
1151
1329
  for (idx_t i = 0; i < n; i++) {
1152
1330
  for (int j = 0; j < d_out; j++) {
1153
- if (map[j] >= 0) x[map[j]] = xt[j];
1331
+ if (map[j] >= 0)
1332
+ x[map[j]] = xt[j];
1154
1333
  }
1155
1334
  x += d_in;
1156
1335
  xt += d_out;