umappp 0.1.5 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/README.md +11 -4
- data/ext/umappp/umappp.cpp +41 -43
- data/lib/umappp/version.rb +1 -1
- data/lib/umappp.rb +5 -4
- data/vendor/aarand/aarand.hpp +141 -28
- data/vendor/annoy/annoylib.h +1 -1
- data/vendor/hnswlib/bruteforce.h +142 -127
- data/vendor/hnswlib/hnswalg.h +1018 -939
- data/vendor/hnswlib/hnswlib.h +149 -58
- data/vendor/hnswlib/space_ip.h +322 -229
- data/vendor/hnswlib/space_l2.h +283 -240
- data/vendor/hnswlib/visited_list_pool.h +54 -55
- data/vendor/irlba/irlba.hpp +12 -27
- data/vendor/irlba/lanczos.hpp +30 -31
- data/vendor/irlba/parallel.hpp +37 -38
- data/vendor/irlba/utils.hpp +12 -23
- data/vendor/irlba/wrappers.hpp +239 -70
- data/vendor/kmeans/Details.hpp +1 -1
- data/vendor/kmeans/HartiganWong.hpp +28 -2
- data/vendor/kmeans/InitializeKmeansPP.hpp +29 -1
- data/vendor/kmeans/Kmeans.hpp +25 -2
- data/vendor/kmeans/Lloyd.hpp +29 -2
- data/vendor/kmeans/MiniBatch.hpp +48 -8
- data/vendor/knncolle/Annoy/Annoy.hpp +3 -0
- data/vendor/knncolle/Hnsw/Hnsw.hpp +3 -0
- data/vendor/knncolle/Kmknn/Kmknn.hpp +11 -1
- data/vendor/knncolle/utils/find_nearest_neighbors.hpp +8 -6
- data/vendor/umappp/Umap.hpp +85 -43
- data/vendor/umappp/optimize_layout.hpp +410 -133
- data/vendor/umappp/spectral_init.hpp +4 -1
- metadata +6 -6
data/vendor/kmeans/MiniBatch.hpp
CHANGED
@@ -53,29 +53,34 @@ public:
|
|
53
53
|
*/
|
54
54
|
struct Defaults {
|
55
55
|
/**
|
56
|
-
* See `
|
56
|
+
* See `set_max_iterations()` for more details.
|
57
57
|
*/
|
58
58
|
static constexpr int max_iterations = 100;
|
59
59
|
|
60
60
|
/**
|
61
|
-
* See `
|
61
|
+
* See `set_batch_size()` for more details.
|
62
62
|
*/
|
63
63
|
static constexpr INDEX_t batch_size = 500;
|
64
64
|
|
65
65
|
/**
|
66
|
-
* See `
|
66
|
+
* See `set_max_change_proportion()` for more details.
|
67
67
|
*/
|
68
68
|
static constexpr double max_change_proportion = 0.01;
|
69
69
|
|
70
70
|
/**
|
71
|
-
* See `
|
71
|
+
* See `set_convergence_history()` for more details.
|
72
72
|
*/
|
73
73
|
static constexpr int convergence_history = 10;
|
74
74
|
|
75
75
|
/**
|
76
|
-
* See `
|
76
|
+
* See `set_seed()` for more details.
|
77
77
|
*/
|
78
78
|
static constexpr uint64_t seed = 1234567890;
|
79
|
+
|
80
|
+
/**
|
81
|
+
* See `set_num_threads()` for more details.
|
82
|
+
*/
|
83
|
+
static constexpr int num_threads = 1;
|
79
84
|
};
|
80
85
|
|
81
86
|
private:
|
@@ -88,6 +93,8 @@ private:
|
|
88
93
|
double max_change = Defaults::max_change_proportion;
|
89
94
|
|
90
95
|
uint64_t seed = Defaults::seed;
|
96
|
+
|
97
|
+
int nthreads = Defaults::num_threads;
|
91
98
|
public:
|
92
99
|
/**
|
93
100
|
* @param i Maximum number of iterations.
|
@@ -143,6 +150,16 @@ public:
|
|
143
150
|
return *this;
|
144
151
|
}
|
145
152
|
|
153
|
+
/**
|
154
|
+
* @param n Number of threads to use.
|
155
|
+
*
|
156
|
+
* @return A reference to this `MiniBatch` object.
|
157
|
+
*/
|
158
|
+
MiniBatch& set_num_threads(int n = Defaults::num_threads) {
|
159
|
+
nthreads = n;
|
160
|
+
return *this;
|
161
|
+
}
|
162
|
+
|
146
163
|
public:
|
147
164
|
/**
|
148
165
|
* @param ndim Number of dimensions.
|
@@ -183,10 +200,22 @@ public:
|
|
183
200
|
}
|
184
201
|
|
185
202
|
QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers);
|
186
|
-
|
187
|
-
|
203
|
+
size_t nchosen = chosen.size();
|
204
|
+
|
205
|
+
#ifndef KMEANS_CUSTOM_PARALLEL
|
206
|
+
#pragma omp parallel for num_threads(nthreads)
|
207
|
+
for (size_t i = 0; i < nchosen; ++i) {
|
208
|
+
#else
|
209
|
+
KMEANS_CUSTOM_PARALLEL(nchosen, [&](size_t first, size_t last) -> void {
|
210
|
+
for (size_t i = first; i < last; ++i) {
|
211
|
+
#endif
|
188
212
|
clusters[chosen[i]] = index.find(data + chosen[i] * ndim);
|
213
|
+
#ifndef KMEANS_CUSTOM_PARALLEL
|
214
|
+
}
|
215
|
+
#else
|
189
216
|
}
|
217
|
+
}, nthreads);
|
218
|
+
#endif
|
190
219
|
|
191
220
|
// Updating the means for each cluster.
|
192
221
|
for (auto o : chosen) {
|
@@ -236,10 +265,21 @@ public:
|
|
236
265
|
|
237
266
|
// Run through all observations to make sure they have the latest cluster assignments.
|
238
267
|
QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers);
|
239
|
-
|
268
|
+
|
269
|
+
#ifndef KMEANS_CUSTOM_PARALLEL
|
270
|
+
#pragma omp parallel for num_threads(nthreads)
|
240
271
|
for (INDEX_t o = 0; o < nobs; ++o) {
|
272
|
+
#else
|
273
|
+
KMEANS_CUSTOM_PARALLEL(nobs, [&](INDEX_t first, INDEX_t last) -> void {
|
274
|
+
for (INDEX_t o = first; o < last; ++o) {
|
275
|
+
#endif
|
241
276
|
clusters[o] = index.find(data + o * ndim);
|
277
|
+
#ifndef KMEANS_CUSTOM_PARALLEL
|
242
278
|
}
|
279
|
+
#else
|
280
|
+
}
|
281
|
+
}, nthreads);
|
282
|
+
#endif
|
243
283
|
|
244
284
|
std::fill(total_sampled.begin(), total_sampled.end(), 0);
|
245
285
|
for (INDEX_t o = 0; o < nobs; ++o) {
|
@@ -25,6 +25,9 @@ namespace knncolle {
|
|
25
25
|
* For a given query point, each tree is searched to identify the subset of all points in the same leaf node as the query point.
|
26
26
|
* The union of these subsets across all trees is exhaustively searched to identify the actual nearest neighbors to the query.
|
27
27
|
*
|
28
|
+
* Note that, to improve reproducibility across architectures, we have disabled manual vectorization of the distance calculations by default.
|
29
|
+
* This can be restored by defining the `KNNCOLLE_MANUAL_VECTORIZATION` macro.
|
30
|
+
*
|
28
31
|
* @see
|
29
32
|
* Bernhardsson E (2018).
|
30
33
|
* Annoy.
|
@@ -25,6 +25,9 @@ namespace knncolle {
|
|
25
25
|
* The HNSW algorithm extends this idea by using a hierarchy of such graphs containing links of different lengths,
|
26
26
|
* which avoids wasting time on small steps in the early stages of the search where the current node position is far from the query.
|
27
27
|
*
|
28
|
+
* Note that, to improve reproducibility across architectures, we have disabled manual vectorization of the distance calculations by default.
|
29
|
+
* This can be restored by defining the `KNNCOLLE_MANUAL_VECTORIZATION` macro.
|
30
|
+
*
|
28
31
|
* @see
|
29
32
|
* Malkov YA, Yashunin DA (2016).
|
30
33
|
* Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs.
|
@@ -16,6 +16,12 @@
|
|
16
16
|
#include <iostream>
|
17
17
|
#endif
|
18
18
|
|
19
|
+
#ifndef KMEANS_CUSTOM_PARALLEL
|
20
|
+
#ifdef KNNCOLLE_CUSTOM_PARALLEL
|
21
|
+
#define KMEANS_CUSTOM_PARALLEL KNNCOLLE_CUSTOM_PARALLEL
|
22
|
+
#endif
|
23
|
+
#endif
|
24
|
+
|
19
25
|
/**
|
20
26
|
* @file Kmknn.hpp
|
21
27
|
*
|
@@ -75,11 +81,12 @@ public:
|
|
75
81
|
* i.e., contiguous elements belong to the same observation.
|
76
82
|
* @param power Power of `nobs` to define the number of cluster centers.
|
77
83
|
* By default, a square root is performed.
|
84
|
+
* @param nthreads Number of threads to use for the k-means clustering.
|
78
85
|
*
|
79
86
|
* @tparam INPUT_t Floating-point type of the input data.
|
80
87
|
*/
|
81
88
|
template<typename INPUT_t>
|
82
|
-
Kmknn(INDEX_t ndim, INDEX_t nobs, const INPUT_t* vals, double power = 0.5) :
|
89
|
+
Kmknn(INDEX_t ndim, INDEX_t nobs, const INPUT_t* vals, double power = 0.5, int nthreads = 1) :
|
83
90
|
num_dim(ndim),
|
84
91
|
num_obs(nobs),
|
85
92
|
data(ndim * nobs),
|
@@ -103,6 +110,9 @@ public:
|
|
103
110
|
std::copy(vals, vals + data.size(), data.data());
|
104
111
|
host = data.data();
|
105
112
|
}
|
113
|
+
|
114
|
+
kmeans::Kmeans<INTERNAL_t, int> krunner;
|
115
|
+
krunner.set_num_threads(nthreads);
|
106
116
|
auto output = kmeans::Kmeans<INTERNAL_t, int>().run(ndim, nobs, host, ncenters, centers.data(), clusters.data());
|
107
117
|
std::swap(sizes, output.sizes);
|
108
118
|
|
@@ -36,17 +36,18 @@ using NeighborList = std::vector<std::vector<std::pair<INDEX_t, DISTANCE_t> > >;
|
|
36
36
|
*
|
37
37
|
* @param ptr Pointer to a `Base` index.
|
38
38
|
* @param k Number of nearest neighbors.
|
39
|
+
* @param nthreads Number of threads to use.
|
39
40
|
*
|
40
41
|
* @return A `NeighborList` of length equal to the number of observations in `ptr->nobs()`.
|
41
42
|
* Each entry contains the `k` nearest neighbors for each observation, sorted by increasing distance.
|
42
43
|
*/
|
43
44
|
template<typename INDEX_t = int, typename DISTANCE_t = double, typename InputINDEX_t, typename InputDISTANCE_t, typename InputQUERY_t>
|
44
|
-
NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k) {
|
45
|
+
NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k, int nthreads) {
|
45
46
|
auto n = ptr->nobs();
|
46
47
|
NeighborList<INDEX_t, DISTANCE_t> output(n);
|
47
48
|
|
48
49
|
#ifndef KNNCOLLE_CUSTOM_PARALLEL
|
49
|
-
#pragma omp parallel for
|
50
|
+
#pragma omp parallel for num_threads(nthreads)
|
50
51
|
for (size_t i = 0; i < n; ++i) {
|
51
52
|
#else
|
52
53
|
KNNCOLLE_CUSTOM_PARALLEL(n, [&](size_t first, size_t last) -> void {
|
@@ -62,7 +63,7 @@ NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t
|
|
62
63
|
}
|
63
64
|
}
|
64
65
|
#ifdef KNNCOLLE_CUSTOM_PARALLEL
|
65
|
-
});
|
66
|
+
}, nthreads);
|
66
67
|
#endif
|
67
68
|
|
68
69
|
return output;
|
@@ -79,17 +80,18 @@ NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t
|
|
79
80
|
*
|
80
81
|
* @param ptr Pointer to a `Base` index.
|
81
82
|
* @param k Number of nearest neighbors.
|
83
|
+
* @param nthreads Number of threads to use.
|
82
84
|
*
|
83
85
|
* @return A vector of vectors of length equal to the number of observations in `ptr->nobs()`.
|
84
86
|
* Each vector contains the indices of the `k` nearest neighbors for each observation, sorted by increasing distance.
|
85
87
|
*/
|
86
88
|
template<typename INDEX_t = int, typename InputINDEX_t, typename InputDISTANCE_t, typename InputQUERY_t>
|
87
|
-
std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k) {
|
89
|
+
std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k, int nthreads) {
|
88
90
|
auto n = ptr->nobs();
|
89
91
|
std::vector<std::vector<INDEX_t> > output(n);
|
90
92
|
|
91
93
|
#ifndef KNNCOLLE_CUSTOM_PARALLEL
|
92
|
-
#pragma omp parallel for
|
94
|
+
#pragma omp parallel for num_threads(nthreads)
|
93
95
|
for (size_t i = 0; i < n; ++i) {
|
94
96
|
#else
|
95
97
|
KNNCOLLE_CUSTOM_PARALLEL(n, [&](size_t first, size_t last) -> void {
|
@@ -101,7 +103,7 @@ std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<
|
|
101
103
|
}
|
102
104
|
}
|
103
105
|
#ifdef KNNCOLLE_CUSTOM_PARALLEL
|
104
|
-
});
|
106
|
+
}, nthreads);
|
105
107
|
#endif
|
106
108
|
|
107
109
|
return output;
|
data/vendor/umappp/Umap.hpp
CHANGED
@@ -157,14 +157,14 @@ public:
|
|
157
157
|
static constexpr uint64_t seed = 1234567890;
|
158
158
|
|
159
159
|
/**
|
160
|
-
* See `
|
160
|
+
* See `set_num_threads()`.
|
161
161
|
*/
|
162
|
-
static constexpr
|
162
|
+
static constexpr int num_threads = 1;
|
163
163
|
|
164
164
|
/**
|
165
|
-
* See `
|
165
|
+
* See `set_parallel_optimization()`.
|
166
166
|
*/
|
167
|
-
static constexpr int
|
167
|
+
static constexpr int parallel_optimization = false;
|
168
168
|
};
|
169
169
|
|
170
170
|
private:
|
@@ -184,8 +184,8 @@ private:
|
|
184
184
|
Float b = Defaults::b;
|
185
185
|
Float repulsion_strength = Defaults::repulsion_strength;
|
186
186
|
Float learning_rate = Defaults::learning_rate;
|
187
|
-
bool batch = Defaults::batch;
|
188
187
|
int nthreads = Defaults::num_threads;
|
188
|
+
bool parallel_optimization = Defaults::parallel_optimization;
|
189
189
|
};
|
190
190
|
|
191
191
|
RuntimeParameters rparams;
|
@@ -359,29 +359,13 @@ public:
|
|
359
359
|
return *this;
|
360
360
|
}
|
361
361
|
|
362
|
-
/**
|
363
|
-
* @param b Whether to optimize in batch mode.
|
364
|
-
* Batch mode is required for effective parallelization via OpenMP but may reduce the stability of the gradient descent.
|
365
|
-
*
|
366
|
-
* Batch mode involves computing forces for all observations and applying them simultaneously.
|
367
|
-
* This is in contrast to the default where the location of observation is updated before the forces are computed for the next observation.
|
368
|
-
* As each observation's forces are computed independently, batch mode is more amenable to parallelization;
|
369
|
-
* however, this comes at the cost of stability as the force calculations for later observations are not aware of updates to the positions of earlier observations.
|
370
|
-
*
|
371
|
-
* @return A reference to this `Umap` object.
|
372
|
-
*/
|
373
|
-
Umap& set_batch(bool b = Defaults::batch) {
|
374
|
-
rparams.batch = b;
|
375
|
-
return *this;
|
376
|
-
}
|
377
|
-
|
378
362
|
/**
|
379
363
|
* @param n Number of threads to use.
|
380
364
|
*
|
381
365
|
* @return A reference to this `Umap` object.
|
382
366
|
*
|
383
367
|
* This setting affects nearest neighbor detection (if an existing list of neighbors is not supplied in `initialize()` or `run()`) and spectral initialization.
|
384
|
-
* If `
|
368
|
+
* If `set_parallel_optimization()` is true, it will also affect the layout optimization, i.e., the gradient descent iterations.
|
385
369
|
*
|
386
370
|
* The `UMAPPP_CUSTOM_PARALLEL` macro can be set to a function that specifies a custom parallelization scheme.
|
387
371
|
* This function should be a template that accept three arguments:
|
@@ -404,6 +388,26 @@ public:
|
|
404
388
|
return *this;
|
405
389
|
}
|
406
390
|
|
391
|
+
/**
|
392
|
+
* @param p Whether to enable parallel optimization.
|
393
|
+
* If set to `true`, this will use the number of threads specified in `set_num_threads()` for the layout optimization step.
|
394
|
+
*
|
395
|
+
* @return A reference to this `Umap` object.
|
396
|
+
*
|
397
|
+
* By default, this is set to `false` as the increase in the number of threads is usually not cost-effective for layout optimization.
|
398
|
+
* Specifically, while CPU usage scales with the number of threads, the time spent does not decrease by the same factor.
|
399
|
+
* We also expect that the number of available CPUs is at least equal to the requested number of threads, otherwise contention will greatly degrade performance.
|
400
|
+
* Nonetheless, users can enable parallel optimization if cost is no issue - usually a higher number of threads (above 4) is required to see a reduction in time.
|
401
|
+
*
|
402
|
+
* If the `UMAPPP_NO_PARALLEL_OPTIMIZATION` macro is defined, **umappp** will not be compiled with support for parallel optimization.
|
403
|
+
* This may be desirable in environments that have no support for threading or atomics, or to reduce the binary size if parallelization is not of interest.
|
404
|
+
* In such cases, enabling parallel optimization and calling `Status::run()` will raise an error.
|
405
|
+
*/
|
406
|
+
Umap& set_parallel_optimization(bool p = Defaults::parallel_optimization) {
|
407
|
+
rparams.parallel_optimization = p;
|
408
|
+
return *this;
|
409
|
+
}
|
410
|
+
|
407
411
|
public:
|
408
412
|
/**
|
409
413
|
* @brief Status of the UMAP optimization iterations.
|
@@ -412,15 +416,51 @@ public:
|
|
412
416
|
/**
|
413
417
|
* @cond
|
414
418
|
*/
|
415
|
-
Status(EpochData<Float> e, uint64_t seed, RuntimeParameters p
|
419
|
+
Status(EpochData<Float> e, uint64_t seed, RuntimeParameters p, int n, Float* embed) :
|
420
|
+
epochs(std::move(e)), engine(seed), rparams(std::move(p)), ndim_(n), embedding_(embed) {}
|
416
421
|
|
417
422
|
EpochData<Float> epochs;
|
418
423
|
std::mt19937_64 engine;
|
419
424
|
RuntimeParameters rparams;
|
425
|
+
int ndim_;
|
426
|
+
Float* embedding_;
|
420
427
|
/**
|
421
428
|
* @endcond
|
422
429
|
*/
|
423
430
|
|
431
|
+
/**
|
432
|
+
* @return Number of dimensions of the embedding.
|
433
|
+
*/
|
434
|
+
int ndim() const {
|
435
|
+
return ndim_;
|
436
|
+
}
|
437
|
+
|
438
|
+
/**
|
439
|
+
* @return Pointer to a two-dimensional column-major array where rows are dimensions (`ndim`) and columns are observations.
|
440
|
+
* This is updated by `initialize()` to store the final embedding.
|
441
|
+
*/
|
442
|
+
const Float* embedding() const {
|
443
|
+
return embedding_;
|
444
|
+
}
|
445
|
+
|
446
|
+
/**
|
447
|
+
* @param ptr Pointer to a two-dimensional array as described in `embedding()`.
|
448
|
+
* @param copy Whether the contents of the previous array should be copied into `ptr`.
|
449
|
+
*
|
450
|
+
* By default, the `Status` objects returned by `Umap` methods will operate on embeddings in an array specified at `Status` construction time.
|
451
|
+
* This method will change the embedding array for an existing `Status` object, which can be helpful in some situations,
|
452
|
+
* e.g., to clone a `Status` object and to store its embeddings in a different array than the object.
|
453
|
+
*
|
454
|
+
* Note that the contents of the new array in `ptr` should be the same as the array that it replaces, as `run()` will continue the iteration from the coordinates inside the array.
|
455
|
+
* If a copy was already performed from the old array to the new array, the caller may set `copy = false` to avoid an extra copy.
|
456
|
+
*/
|
457
|
+
void set_embedding(Float* ptr, bool copy = true) {
|
458
|
+
if (copy) {
|
459
|
+
std::copy(embedding_, embedding_ + static_cast<size_t>(ndim()) * nobs(), ptr);
|
460
|
+
}
|
461
|
+
embedding_ = ptr;
|
462
|
+
}
|
463
|
+
|
424
464
|
/**
|
425
465
|
* @return Current epoch.
|
426
466
|
*/
|
@@ -444,21 +484,22 @@ public:
|
|
444
484
|
}
|
445
485
|
|
446
486
|
/**
|
447
|
-
*
|
448
|
-
*
|
449
|
-
* This contains the initial coordinates and is updated to store the final embedding.
|
487
|
+
* The status of the algorithm and the coordinates in `embedding()` are updated to the specified number of epochs.
|
488
|
+
*
|
450
489
|
* @param epoch_limit Number of epochs to run to.
|
451
490
|
* The actual number of epochs performed is equal to the difference between `epoch_limit` and the current number of epochs in `epoch()`.
|
452
|
-
* `epoch_limit` should be not less than `epoch()` and no greater than the maximum number of epochs specified in `Umap::set_num_epochs()`.
|
491
|
+
* `epoch_limit` should be not less than `epoch()` and be no greater than the maximum number of epochs specified in `Umap::set_num_epochs()`.
|
453
492
|
* If zero, defaults to the maximum number of epochs.
|
454
|
-
*
|
455
|
-
* @return The status of the algorithm and the coordinates in `embedding` are updated to the specified number of epochs.
|
456
493
|
*/
|
457
|
-
void run(int
|
458
|
-
if (
|
494
|
+
void run(int epoch_limit = 0) {
|
495
|
+
if (epoch_limit == 0) {
|
496
|
+
epoch_limit = epochs.total_epochs;
|
497
|
+
}
|
498
|
+
|
499
|
+
if (rparams.nthreads == 1 || !rparams.parallel_optimization) {
|
459
500
|
optimize_layout(
|
460
|
-
|
461
|
-
|
501
|
+
ndim_,
|
502
|
+
embedding_,
|
462
503
|
epochs,
|
463
504
|
rparams.a,
|
464
505
|
rparams.b,
|
@@ -468,16 +509,15 @@ public:
|
|
468
509
|
epoch_limit
|
469
510
|
);
|
470
511
|
} else {
|
471
|
-
|
472
|
-
|
473
|
-
|
512
|
+
optimize_layout_parallel(
|
513
|
+
ndim_,
|
514
|
+
embedding_,
|
474
515
|
epochs,
|
475
516
|
rparams.a,
|
476
517
|
rparams.b,
|
477
518
|
rparams.repulsion_strength,
|
478
519
|
rparams.learning_rate,
|
479
|
-
|
480
|
-
[](decltype(engine()) s) -> auto { return std::mt19937_64(s); },
|
520
|
+
engine,
|
481
521
|
epoch_limit,
|
482
522
|
rparams.nthreads
|
483
523
|
);
|
@@ -524,7 +564,9 @@ public:
|
|
524
564
|
return Status(
|
525
565
|
similarities_to_epochs(x, num_epochs_to_do, negative_sample_rate),
|
526
566
|
seed,
|
527
|
-
std::move(pcopy)
|
567
|
+
std::move(pcopy),
|
568
|
+
ndim,
|
569
|
+
embedding
|
528
570
|
);
|
529
571
|
}
|
530
572
|
|
@@ -587,7 +629,7 @@ public:
|
|
587
629
|
*/
|
588
630
|
template<typename Input = Float>
|
589
631
|
Status initialize(int ndim_in, size_t nobs, const Input* input, int ndim_out, Float* embedding) {
|
590
|
-
knncolle::VpTreeEuclidean
|
632
|
+
knncolle::VpTreeEuclidean<int, Input, Input, Input> searcher(ndim_in, nobs, input);
|
591
633
|
return initialize(&searcher, ndim_out, embedding);
|
592
634
|
}
|
593
635
|
#endif
|
@@ -609,7 +651,7 @@ public:
|
|
609
651
|
template<class Algorithm>
|
610
652
|
Status run(const Algorithm* searcher, int ndim, Float* embedding, int epoch_limit = 0) {
|
611
653
|
auto status = initialize(searcher, ndim, embedding);
|
612
|
-
status.run(
|
654
|
+
status.run(epoch_limit);
|
613
655
|
return status;
|
614
656
|
}
|
615
657
|
|
@@ -627,7 +669,7 @@ public:
|
|
627
669
|
*/
|
628
670
|
Status run(NeighborList<Float> x, int ndim, Float* embedding, int epoch_limit = 0) const {
|
629
671
|
auto status = initialize(std::move(x), ndim, embedding);
|
630
|
-
status.run(
|
672
|
+
status.run(epoch_limit);
|
631
673
|
return status;
|
632
674
|
}
|
633
675
|
|
@@ -651,7 +693,7 @@ public:
|
|
651
693
|
template<typename Input = Float>
|
652
694
|
Status run(int ndim_in, size_t nobs, const Input* input, int ndim_out, Float* embedding, int epoch_limit = 0) {
|
653
695
|
auto status = initialize(ndim_in, nobs, input, ndim_out, embedding);
|
654
|
-
status.run(
|
696
|
+
status.run(epoch_limit);
|
655
697
|
return status;
|
656
698
|
}
|
657
699
|
#endif
|