umappp 0.1.6 → 0.2.1
This diff shows the changes between the two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/README.md +22 -16
- data/ext/umappp/numo.hpp +957 -833
- data/ext/umappp/umappp.cpp +39 -45
- data/lib/umappp/version.rb +1 -1
- data/lib/umappp.rb +5 -4
- data/vendor/aarand/aarand.hpp +141 -28
- data/vendor/annoy/annoylib.h +1 -1
- data/vendor/hnswlib/bruteforce.h +142 -127
- data/vendor/hnswlib/hnswalg.h +1018 -939
- data/vendor/hnswlib/hnswlib.h +149 -58
- data/vendor/hnswlib/space_ip.h +322 -229
- data/vendor/hnswlib/space_l2.h +283 -240
- data/vendor/hnswlib/visited_list_pool.h +54 -55
- data/vendor/irlba/irlba.hpp +12 -27
- data/vendor/irlba/lanczos.hpp +30 -31
- data/vendor/irlba/parallel.hpp +37 -38
- data/vendor/irlba/utils.hpp +12 -23
- data/vendor/irlba/wrappers.hpp +239 -70
- data/vendor/kmeans/Details.hpp +1 -1
- data/vendor/kmeans/HartiganWong.hpp +28 -2
- data/vendor/kmeans/InitializeKmeansPP.hpp +29 -1
- data/vendor/kmeans/Kmeans.hpp +25 -2
- data/vendor/kmeans/Lloyd.hpp +29 -2
- data/vendor/kmeans/MiniBatch.hpp +48 -8
- data/vendor/knncolle/Annoy/Annoy.hpp +3 -0
- data/vendor/knncolle/Hnsw/Hnsw.hpp +3 -0
- data/vendor/knncolle/Kmknn/Kmknn.hpp +11 -1
- data/vendor/knncolle/utils/find_nearest_neighbors.hpp +8 -6
- data/vendor/umappp/Umap.hpp +85 -43
- data/vendor/umappp/optimize_layout.hpp +410 -133
- data/vendor/umappp/spectral_init.hpp +4 -1
- metadata +7 -10
data/vendor/kmeans/MiniBatch.hpp
CHANGED
@@ -53,29 +53,34 @@ public:
      */
     struct Defaults {
         /**
-         * See `
+         * See `set_max_iterations()` for more details.
         */
        static constexpr int max_iterations = 100;
 
        /**
-         * See `
+         * See `set_batch_size()` for more details.
         */
        static constexpr INDEX_t batch_size = 500;
 
        /**
-         * See `
+         * See `set_max_change_proportion()` for more details.
         */
        static constexpr double max_change_proportion = 0.01;
 
        /**
-         * See `
+         * See `set_convergence_history()` for more details.
         */
        static constexpr int convergence_history = 10;
 
        /**
-         * See `
+         * See `set_seed()` for more details.
         */
        static constexpr uint64_t seed = 1234567890;
+
+        /**
+         * See `set_num_threads()` for more details.
+         */
+        static constexpr int num_threads = 1;
    };
 
 private:
@@ -88,6 +93,8 @@ private:
     double max_change = Defaults::max_change_proportion;
 
     uint64_t seed = Defaults::seed;
+
+    int nthreads = Defaults::num_threads;
 public:
     /**
      * @param i Maximum number of iterations.
@@ -143,6 +150,16 @@ public:
         return *this;
     }
 
+    /**
+     * @param n Number of threads to use.
+     *
+     * @return A reference to this `MiniBatch` object.
+     */
+    MiniBatch& set_num_threads(int n = Defaults::num_threads) {
+        nthreads = n;
+        return *this;
+    }
+
 public:
     /**
      * @param ndim Number of dimensions.
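The new `set_num_threads()` slots into the existing builder-style setters. Below is a minimal configuration sketch; the template arguments for `kmeans::MiniBatch` are assumptions for illustration, not taken from the diff.

```cpp
#include "kmeans/MiniBatch.hpp"

void configure_minibatch() {
    // DATA_t, CLUSTER_t, INDEX_t are guesses at the template parameters.
    kmeans::MiniBatch<double, int, int> runner;
    runner.set_max_iterations(200)   // Defaults::max_iterations is 100
          .set_batch_size(1000)      // Defaults::batch_size is 500
          .set_seed(42)              // Defaults::seed
          .set_num_threads(4);       // new in this release; drives the assignment loops below
}
```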
@@ -183,10 +200,22 @@ public:
        }
 
        QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers);
-
-
+        size_t nchosen = chosen.size();
+
+#ifndef KMEANS_CUSTOM_PARALLEL
+        #pragma omp parallel for num_threads(nthreads)
+        for (size_t i = 0; i < nchosen; ++i) {
+#else
+        KMEANS_CUSTOM_PARALLEL(nchosen, [&](size_t first, size_t last) -> void {
+        for (size_t i = first; i < last; ++i) {
+#endif
            clusters[chosen[i]] = index.find(data + chosen[i] * ndim);
+#ifndef KMEANS_CUSTOM_PARALLEL
+        }
+#else
        }
+        }, nthreads);
+#endif
 
        // Updating the means for each cluster.
        for (auto o : chosen) {
@@ -236,10 +265,21 @@ public:
 
        // Run through all observations to make sure they have the latest cluster assignments.
        QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers);
-
+
+#ifndef KMEANS_CUSTOM_PARALLEL
+        #pragma omp parallel for num_threads(nthreads)
        for (INDEX_t o = 0; o < nobs; ++o) {
+#else
+        KMEANS_CUSTOM_PARALLEL(nobs, [&](INDEX_t first, INDEX_t last) -> void {
+        for (INDEX_t o = first; o < last; ++o) {
+#endif
            clusters[o] = index.find(data + o * ndim);
+#ifndef KMEANS_CUSTOM_PARALLEL
        }
+#else
+        }
+        }, nthreads);
+#endif
 
        std::fill(total_sampled.begin(), total_sampled.end(), 0);
        for (INDEX_t o = 0; o < nobs; ++o) {
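Both assignment loops follow the same pattern: an OpenMP `parallel for` when `KMEANS_CUSTOM_PARALLEL` is undefined, otherwise a call to that macro with the job count, a functor over a `[first, last)` range, and the thread count. A sketch of a compatible scheduler built on `std::thread` follows; the helper name `run_in_chunks` is ours, not part of the vendored code.

```cpp
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Splits [0, njobs) into contiguous chunks and runs each chunk on its own thread,
// matching the (njobs, body, nthreads) calling convention used above.
template<class Body>
void run_in_chunks(size_t njobs, Body body, int nthreads) {
    size_t per_thread = (njobs + nthreads - 1) / nthreads; // ceiling division
    std::vector<std::thread> workers;
    for (size_t first = 0; first < njobs; first += per_thread) {
        size_t last = std::min(njobs, first + per_thread);
        workers.emplace_back([&body, first, last]() { body(first, last); });
    }
    for (auto& w : workers) {
        w.join();
    }
}

// A build could then pass -DKMEANS_CUSTOM_PARALLEL=run_in_chunks to bypass OpenMP.
```

Kmknn.hpp below forwards `KNNCOLLE_CUSTOM_PARALLEL` to `KMEANS_CUSTOM_PARALLEL`, so a build that already defines the knncolle hook reuses the same scheduler here.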
data/vendor/knncolle/Annoy/Annoy.hpp
CHANGED
@@ -25,6 +25,9 @@ namespace knncolle {
  * For a given query point, each tree is searched to identify the subset of all points in the same leaf node as the query point.
  * The union of these subsets across all trees is exhaustively searched to identify the actual nearest neighbors to the query.
  *
+ * Note that, to improve reproducibility across architectures, we have disabled manual vectorization of the distance calculations by default.
+ * This can be restored by defining the `KNNCOLLE_MANUAL_VECTORIZATION` macro.
+ *
  * @see
  * Bernhardsson E (2018).
  * Annoy.
data/vendor/knncolle/Hnsw/Hnsw.hpp
CHANGED
@@ -25,6 +25,9 @@ namespace knncolle {
  * The HNSW algorithm extends this idea by using a hierarchy of such graphs containing links of different lengths,
  * which avoids wasting time on small steps in the early stages of the search where the current node position is far from the query.
  *
+ * Note that, to improve reproducibility across architectures, we have disabled manual vectorization of the distance calculations by default.
+ * This can be restored by defining the `KNNCOLLE_MANUAL_VECTORIZATION` macro.
+ *
  * @see
  * Malkov YA, Yashunin DA (2016).
  * Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs.
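Both wrappers document the same escape hatch: define `KNNCOLLE_MANUAL_VECTORIZATION` to restore the hand-written SIMD distance kernels. A sketch of how a consumer might opt back in; the include paths follow the vendored layout and are otherwise an assumption.

```cpp
// Accept architecture-dependent results in exchange for the vectorized kernels.
#define KNNCOLLE_MANUAL_VECTORIZATION
#include "knncolle/Annoy/Annoy.hpp"
#include "knncolle/Hnsw/Hnsw.hpp"
```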
data/vendor/knncolle/Kmknn/Kmknn.hpp
CHANGED
@@ -16,6 +16,12 @@
 #include <iostream>
 #endif
 
+#ifndef KMEANS_CUSTOM_PARALLEL
+#ifdef KNNCOLLE_CUSTOM_PARALLEL
+#define KMEANS_CUSTOM_PARALLEL KNNCOLLE_CUSTOM_PARALLEL
+#endif
+#endif
+
 /**
  * @file Kmknn.hpp
  *
@@ -75,11 +81,12 @@ public:
     * i.e., contiguous elements belong to the same observation.
     * @param power Power of `nobs` to define the number of cluster centers.
     * By default, a square root is performed.
+     * @param nthreads Number of threads to use for the k-means clustering.
     *
     * @tparam INPUT_t Floating-point type of the input data.
     */
    template<typename INPUT_t>
-    Kmknn(INDEX_t ndim, INDEX_t nobs, const INPUT_t* vals, double power = 0.5) :
+    Kmknn(INDEX_t ndim, INDEX_t nobs, const INPUT_t* vals, double power = 0.5, int nthreads = 1) :
        num_dim(ndim),
        num_obs(nobs),
        data(ndim * nobs),
@@ -103,6 +110,9 @@ public:
            std::copy(vals, vals + data.size(), data.data());
            host = data.data();
        }
+
+        kmeans::Kmeans<INTERNAL_t, int> krunner;
+        krunner.set_num_threads(nthreads);
        auto output = kmeans::Kmeans<INTERNAL_t, int>().run(ndim, nobs, host, ncenters, centers.data(), clusters.data());
        std::swap(sizes, output.sizes);
 
data/vendor/knncolle/utils/find_nearest_neighbors.hpp
CHANGED
@@ -36,17 +36,18 @@ using NeighborList = std::vector<std::vector<std::pair<INDEX_t, DISTANCE_t> > >;
  *
  * @param ptr Pointer to a `Base` index.
  * @param k Number of nearest neighbors.
+ * @param nthreads Number of threads to use.
  *
  * @return A `NeighborList` of length equal to the number of observations in `ptr->nobs()`.
  * Each entry contains the `k` nearest neighbors for each observation, sorted by increasing distance.
  */
 template<typename INDEX_t = int, typename DISTANCE_t = double, typename InputINDEX_t, typename InputDISTANCE_t, typename InputQUERY_t>
-NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k) {
+NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k, int nthreads) {
    auto n = ptr->nobs();
    NeighborList<INDEX_t, DISTANCE_t> output(n);
 
#ifndef KNNCOLLE_CUSTOM_PARALLEL
-    #pragma omp parallel for
+    #pragma omp parallel for num_threads(nthreads)
    for (size_t i = 0; i < n; ++i) {
#else
    KNNCOLLE_CUSTOM_PARALLEL(n, [&](size_t first, size_t last) -> void {
@@ -62,7 +63,7 @@ NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t
        }
    }
#ifdef KNNCOLLE_CUSTOM_PARALLEL
-    });
+    }, nthreads);
#endif
 
    return output;
@@ -79,17 +80,18 @@ NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t
  *
  * @param ptr Pointer to a `Base` index.
  * @param k Number of nearest neighbors.
+ * @param nthreads Number of threads to use.
  *
  * @return A vector of vectors of length equal to the number of observations in `ptr->nobs()`.
  * Each vector contains the indices of the `k` nearest neighbors for each observation, sorted by increasing distance.
  */
 template<typename INDEX_t = int, typename InputINDEX_t, typename InputDISTANCE_t, typename InputQUERY_t>
-std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k) {
+std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k, int nthreads) {
    auto n = ptr->nobs();
    std::vector<std::vector<INDEX_t> > output(n);
 
#ifndef KNNCOLLE_CUSTOM_PARALLEL
-    #pragma omp parallel for
+    #pragma omp parallel for num_threads(nthreads)
    for (size_t i = 0; i < n; ++i) {
#else
    KNNCOLLE_CUSTOM_PARALLEL(n, [&](size_t first, size_t last) -> void {
@@ -101,7 +103,7 @@ std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<
        }
    }
#ifdef KNNCOLLE_CUSTOM_PARALLEL
-    });
+    }, nthreads);
#endif
 
    return output;
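Callers now pass the thread count explicitly. A sketch of a wrapper that queries any `knncolle::Base` index for 15 neighbors on 4 threads; the include path mirrors the vendored layout, and the rest follows the signatures shown above.

```cpp
#include "knncolle/utils/find_nearest_neighbors.hpp"

template<typename I, typename D, typename Q>
knncolle::NeighborList<int, double> neighbors15(const knncolle::Base<I, D, Q>* index) {
    // k = 15 nearest neighbors per observation, computed across 4 threads.
    return knncolle::find_nearest_neighbors<int, double>(index, 15, 4);
}
```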
data/vendor/umappp/Umap.hpp
CHANGED
@@ -157,14 +157,14 @@ public:
        static constexpr uint64_t seed = 1234567890;
 
        /**
-         * See `
+         * See `set_num_threads()`.
         */
-        static constexpr
+        static constexpr int num_threads = 1;
 
        /**
-         * See `
+         * See `set_parallel_optimization()`.
         */
-        static constexpr int
+        static constexpr int parallel_optimization = false;
    };
 
private:
@@ -184,8 +184,8 @@ private:
        Float b = Defaults::b;
        Float repulsion_strength = Defaults::repulsion_strength;
        Float learning_rate = Defaults::learning_rate;
-        bool batch = Defaults::batch;
        int nthreads = Defaults::num_threads;
+        bool parallel_optimization = Defaults::parallel_optimization;
    };
 
    RuntimeParameters rparams;
@@ -359,29 +359,13 @@ public:
        return *this;
    }
 
-    /**
-     * @param b Whether to optimize in batch mode.
-     * Batch mode is required for effective parallelization via OpenMP but may reduce the stability of the gradient descent.
-     *
-     * Batch mode involves computing forces for all observations and applying them simultaneously.
-     * This is in contrast to the default where the location of observation is updated before the forces are computed for the next observation.
-     * As each observation's forces are computed independently, batch mode is more amenable to parallelization;
-     * however, this comes at the cost of stability as the force calculations for later observations are not aware of updates to the positions of earlier observations.
-     *
-     * @return A reference to this `Umap` object.
-     */
-    Umap& set_batch(bool b = Defaults::batch) {
-        rparams.batch = b;
-        return *this;
-    }
-
    /**
     * @param n Number of threads to use.
     *
     * @return A reference to this `Umap` object.
     *
     * This setting affects nearest neighbor detection (if an existing list of neighbors is not supplied in `initialize()` or `run()`) and spectral initialization.
-     * If `
+     * If `set_parallel_optimization()` is true, it will also affect the layout optimization, i.e., the gradient descent iterations.
     *
     * The `UMAPPP_CUSTOM_PARALLEL` macro can be set to a function that specifies a custom parallelization scheme.
     * This function should be a template that accept three arguments:
@@ -404,6 +388,26 @@ public:
        return *this;
    }
 
+    /**
+     * @param p Whether to enable parallel optimization.
+     * If set to `true`, this will use the number of threads specified in `set_num_threads()` for the layout optimization step.
+     *
+     * @return A reference to this `Umap` object.
+     *
+     * By default, this is set to `false` as the increase in the number of threads is usually not cost-effective for layout optimization.
+     * Specifically, while CPU usage scales with the number of threads, the time spent does not decrease by the same factor.
+     * We also expect that the number of available CPUs is at least equal to the requested number of threads, otherwise contention will greatly degrade performance.
+     * Nonetheless, users can enable parallel optimization if cost is no issue - usually a higher number of threads (above 4) is required to see a reduction in time.
+     *
+     * If the `UMAPPP_NO_PARALLEL_OPTIMIZATION` macro is defined, **umappp** will not be compiled with support for parallel optimization.
+     * This may be desirable in environments that have no support for threading or atomics, or to reduce the binary size if parallelization is not of interest.
+     * In such cases, enabling parallel optimization and calling `Status::run()` will raise an error.
+     */
+    Umap& set_parallel_optimization(bool p = Defaults::parallel_optimization) {
+        rparams.parallel_optimization = p;
+        return *this;
+    }
+
public:
    /**
     * @brief Status of the UMAP optimization iterations.
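Putting the two new settings together: the thread count always applies to neighbor search and spectral initialization, and additionally to the gradient descent once parallel optimization is switched on. A minimal sketch, assuming `umappp::Umap` is templated on the floating-point type and using the `run()` overload shown further down:

```cpp
#include <vector>
#include "umappp/Umap.hpp"

std::vector<double> embed_2d(const double* data, int ndim, size_t nobs) {
    std::vector<double> embedding(2 * nobs);   // column-major, 2 rows x nobs columns

    umappp::Umap<double> runner;
    runner.set_num_threads(4)                  // neighbor search + spectral initialization
          .set_parallel_optimization(true);    // opt in: also parallelize the layout optimization

    runner.run(ndim, nobs, data, 2, embedding.data());
    return embedding;
}
```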
@@ -412,15 +416,51 @@ public:
        /**
         * @cond
         */
-        Status(EpochData<Float> e, uint64_t seed, RuntimeParameters p
+        Status(EpochData<Float> e, uint64_t seed, RuntimeParameters p, int n, Float* embed) :
+            epochs(std::move(e)), engine(seed), rparams(std::move(p)), ndim_(n), embedding_(embed) {}
 
        EpochData<Float> epochs;
        std::mt19937_64 engine;
        RuntimeParameters rparams;
+        int ndim_;
+        Float* embedding_;
        /**
         * @endcond
         */
 
+        /**
+         * @return Number of dimensions of the embedding.
+         */
+        int ndim() const {
+            return ndim_;
+        }
+
+        /**
+         * @return Pointer to a two-dimensional column-major array where rows are dimensions (`ndim`) and columns are observations.
+         * This is updated by `initialize()` to store the final embedding.
+         */
+        const Float* embedding() const {
+            return embedding_;
+        }
+
+        /**
+         * @param ptr Pointer to a two-dimensional array as described in `embedding()`.
+         * @param copy Whether the contents of the previous array should be copied into `ptr`.
+         *
+         * By default, the `Status` objects returned by `Umap` methods will operate on embeddings in an array specified at `Status` construction time.
+         * This method will change the embedding array for an existing `Status` object, which can be helpful in some situations,
+         * e.g., to clone a `Status` object and to store its embeddings in a different array than the object.
+         *
+         * Note that the contents of the new array in `ptr` should be the same as the array that it replaces, as `run()` will continue the iteration from the coordinates inside the array.
+         * If a copy was already performed from the old array to the new array, the caller may set `copy = false` to avoid an extra copy.
+         */
+        void set_embedding(Float* ptr, bool copy = true) {
+            if (copy) {
+                std::copy(embedding_, embedding_ + static_cast<size_t>(ndim()) * nobs(), ptr);
+            }
+            embedding_ = ptr;
+        }
+
        /**
         * @return Current epoch.
         */
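The new accessors make the coordinate buffer swappable after construction, which is what enables the cloning workflow described above. A sketch under the assumption of `Umap<double>` (so `Float` is `double`):

```cpp
#include <vector>
#include "umappp/Umap.hpp"

// Clone an in-progress Status and give the clone its own buffer, so the original
// and the copy can be iterated independently from the same intermediate state.
void fork_and_finish(const umappp::Umap<double>::Status& status, std::vector<double>& scratch) {
    auto clone = status;                                              // shares the embedding pointer at first
    scratch.resize(static_cast<size_t>(clone.ndim()) * clone.nobs());
    clone.set_embedding(scratch.data());                              // copy = true: current coordinates copied over
    clone.run();                                                      // iterate the clone to the maximum number of epochs
}
```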
@@ -444,21 +484,22 @@ public:
        }
 
        /**
-         *
-         *
-         * This contains the initial coordinates and is updated to store the final embedding.
+         * The status of the algorithm and the coordinates in `embedding()` are updated to the specified number of epochs.
+         *
         * @param epoch_limit Number of epochs to run to.
         * The actual number of epochs performed is equal to the difference between `epoch_limit` and the current number of epochs in `epoch()`.
-         * `epoch_limit` should be not less than `epoch()` and no greater than the maximum number of epochs specified in `Umap::set_num_epochs()`.
+         * `epoch_limit` should be not less than `epoch()` and be no greater than the maximum number of epochs specified in `Umap::set_num_epochs()`.
         * If zero, defaults to the maximum number of epochs.
-         *
-         * @return The status of the algorithm and the coordinates in `embedding` are updated to the specified number of epochs.
         */
-        void run(int
-            if (
+        void run(int epoch_limit = 0) {
+            if (epoch_limit == 0) {
+                epoch_limit = epochs.total_epochs;
+            }
+
+            if (rparams.nthreads == 1 || !rparams.parallel_optimization) {
                optimize_layout(
-
-
+                    ndim_,
+                    embedding_,
                    epochs,
                    rparams.a,
                    rparams.b,
@@ -468,16 +509,15 @@ public:
                    epoch_limit
                );
            } else {
-
-
-
+                optimize_layout_parallel(
+                    ndim_,
+                    embedding_,
                    epochs,
                    rparams.a,
                    rparams.b,
                    rparams.repulsion_strength,
                    rparams.learning_rate,
-
-                    [](decltype(engine()) s) -> auto { return std::mt19937_64(s); },
+                    engine,
                    epoch_limit,
                    rparams.nthreads
                );
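Because `run()` now takes only the epoch limit and always writes into the stored `embedding()` buffer, the optimization can be driven in stages, e.g. to snapshot intermediate layouts. A sketch; the checkpointing itself is left to the caller.

```cpp
// Advance any umappp Status object in two stages.
template<class Status>
void staged_run(Status& status) {
    status.run(100);   // iterate from the current epoch() up to epoch 100
    // ... inspect or copy status.embedding() here for a checkpoint ...
    status.run(0);     // an epoch_limit of zero means "run to the maximum number of epochs"
}
```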
@@ -524,7 +564,9 @@ public:
        return Status(
            similarities_to_epochs(x, num_epochs_to_do, negative_sample_rate),
            seed,
-            std::move(pcopy)
+            std::move(pcopy),
+            ndim,
+            embedding
        );
    }
 
@@ -587,7 +629,7 @@ public:
     */
    template<typename Input = Float>
    Status initialize(int ndim_in, size_t nobs, const Input* input, int ndim_out, Float* embedding) {
-        knncolle::VpTreeEuclidean
+        knncolle::VpTreeEuclidean<int, Input, Input, Input> searcher(ndim_in, nobs, input);
        return initialize(&searcher, ndim_out, embedding);
    }
#endif
@@ -609,7 +651,7 @@ public:
    template<class Algorithm>
    Status run(const Algorithm* searcher, int ndim, Float* embedding, int epoch_limit = 0) {
        auto status = initialize(searcher, ndim, embedding);
-        status.run(
+        status.run(epoch_limit);
        return status;
    }
 
@@ -627,7 +669,7 @@
     */
    Status run(NeighborList<Float> x, int ndim, Float* embedding, int epoch_limit = 0) const {
        auto status = initialize(std::move(x), ndim, embedding);
-        status.run(
+        status.run(epoch_limit);
        return status;
    }
 
@@ -651,7 +693,7 @@
    template<typename Input = Float>
    Status run(int ndim_in, size_t nobs, const Input* input, int ndim_out, Float* embedding, int epoch_limit = 0) {
        auto status = initialize(ndim_in, nobs, input, ndim_out, embedding);
-        status.run(
+        status.run(epoch_limit);
        return status;
    }
#endif