mlpack 4.6.1__cp38-cp38-win_amd64.whl → 4.6.2__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +3 -3
- mlpack/adaboost_classify.cp38-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp38-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp38-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp38-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp38-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp38-win_amd64.pyd +0 -0
- mlpack/cf.cp38-win_amd64.pyd +0 -0
- mlpack/dbscan.cp38-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp38-win_amd64.pyd +0 -0
- mlpack/det.cp38-win_amd64.pyd +0 -0
- mlpack/emst.cp38-win_amd64.pyd +0 -0
- mlpack/fastmks.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp38-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp38-win_amd64.pyd +0 -0
- mlpack/image_converter.cp38-win_amd64.pyd +0 -0
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +21 -12
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +49 -39
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +9 -46
- mlpack/include/mlpack/core/data/save_impl.hpp +315 -315
- mlpack/include/mlpack/core/data/utilities.hpp +158 -158
- mlpack/include/mlpack/core/math/ccov.hpp +1 -0
- mlpack/include/mlpack/core/math/ccov_impl.hpp +4 -5
- mlpack/include/mlpack/core/math/make_alias.hpp +98 -3
- mlpack/include/mlpack/core/util/arma_traits.hpp +19 -2
- mlpack/include/mlpack/core/util/gitversion.hpp +1 -1
- mlpack/include/mlpack/core/util/sfinae_utility.hpp +24 -2
- mlpack/include/mlpack/core/util/version.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -2
- mlpack/include/mlpack/methods/ann/init_rules/network_init.hpp +5 -5
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +3 -2
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +6 -7
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +11 -14
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +5 -4
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +15 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +3 -2
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +14 -15
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +6 -5
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +4 -4
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +3 -2
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +5 -4
- mlpack/include/mlpack/methods/ann/rnn.hpp +19 -18
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +15 -15
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_impl.hpp +3 -8
- mlpack/include/mlpack/methods/decision_tree/fitness_functions/gini_gain.hpp +5 -8
- mlpack/include/mlpack/methods/decision_tree/fitness_functions/information_gain.hpp +5 -8
- mlpack/include/mlpack/methods/gmm/diagonal_gmm_impl.hpp +2 -1
- mlpack/include/mlpack/methods/gmm/eigenvalue_ratio_constraint.hpp +3 -3
- mlpack/include/mlpack/methods/gmm/gmm_impl.hpp +2 -1
- mlpack/include/mlpack/methods/hmm/hmm_impl.hpp +10 -5
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +57 -37
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +69 -59
- mlpack/kde.cp38-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp38-win_amd64.pyd +0 -0
- mlpack/kfn.cp38-win_amd64.pyd +0 -0
- mlpack/kmeans.cp38-win_amd64.pyd +0 -0
- mlpack/knn.cp38-win_amd64.pyd +0 -0
- mlpack/krann.cp38-win_amd64.pyd +0 -0
- mlpack/lars.cp38-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp38-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp38-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp38-win_amd64.pyd +0 -0
- mlpack/lmnn.cp38-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp38-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp38-win_amd64.pyd +0 -0
- mlpack/lsh.cp38-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp38-win_amd64.pyd +0 -0
- mlpack/nbc.cp38-win_amd64.pyd +0 -0
- mlpack/nca.cp38-win_amd64.pyd +0 -0
- mlpack/nmf.cp38-win_amd64.pyd +0 -0
- mlpack/pca.cp38-win_amd64.pyd +0 -0
- mlpack/perceptron.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp38-win_amd64.pyd +0 -0
- mlpack/radical.cp38-win_amd64.pyd +0 -0
- mlpack/random_forest.cp38-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp38-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp38-win_amd64.pyd +0 -0
- mlpack-4.6.2.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/METADATA +2 -2
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/RECORD +102 -102
- mlpack-4.6.1.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/WHEEL +0 -0
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/top_level.txt +0 -0
- mlpack.libs/{.load-order-mlpack-4.6.1 → .load-order-mlpack-4.6.2} +1 -1
|
@@ -1,315 +1,315 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @file core/data/save_impl.hpp
|
|
3
|
-
* @author Ryan Curtin
|
|
4
|
-
* @author Omar Shrit
|
|
5
|
-
*
|
|
6
|
-
* Implementation of save functionality.
|
|
7
|
-
*
|
|
8
|
-
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
-
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
-
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
-
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
-
*/
|
|
13
|
-
#ifndef MLPACK_CORE_DATA_SAVE_IMPL_HPP
|
|
14
|
-
#define MLPACK_CORE_DATA_SAVE_IMPL_HPP
|
|
15
|
-
|
|
16
|
-
// In case it hasn't already been included.
|
|
17
|
-
#include "save.hpp"
|
|
18
|
-
#include "extension.hpp"
|
|
19
|
-
|
|
20
|
-
namespace mlpack {
|
|
21
|
-
namespace data {
|
|
22
|
-
|
|
23
|
-
template<typename eT>
|
|
24
|
-
bool Save(const std::string& filename,
|
|
25
|
-
const arma::Col<eT>& vec,
|
|
26
|
-
const bool fatal,
|
|
27
|
-
FileType inputSaveType)
|
|
28
|
-
{
|
|
29
|
-
// Don't transpose: one observation per line (for CSVs at least).
|
|
30
|
-
return Save(filename, vec, fatal, false, inputSaveType);
|
|
31
|
-
}
|
|
32
|
-
|
|
33
|
-
template<typename eT>
|
|
34
|
-
bool Save(const std::string& filename,
|
|
35
|
-
const arma::Row<eT>& rowvec,
|
|
36
|
-
const bool fatal,
|
|
37
|
-
FileType inputSaveType)
|
|
38
|
-
{
|
|
39
|
-
return Save(filename, rowvec, fatal, true, inputSaveType);
|
|
40
|
-
}
|
|
41
|
-
|
|
42
|
-
// Save a Sparse Matrix
|
|
43
|
-
template<typename eT>
|
|
44
|
-
bool Save(const std::string& filename,
|
|
45
|
-
const arma::SpMat<eT>& matrix,
|
|
46
|
-
const bool fatal,
|
|
47
|
-
bool transpose)
|
|
48
|
-
{
|
|
49
|
-
MatrixOptions opts;
|
|
50
|
-
opts.Fatal() = fatal;
|
|
51
|
-
opts.NoTranspose() = !transpose;
|
|
52
|
-
|
|
53
|
-
return Save(filename, matrix, opts);
|
|
54
|
-
}
|
|
55
|
-
|
|
56
|
-
template<typename eT>
|
|
57
|
-
bool Save(const std::string& filename,
|
|
58
|
-
const arma::Mat<eT>& matrix,
|
|
59
|
-
const bool fatal,
|
|
60
|
-
bool transpose,
|
|
61
|
-
FileType inputSaveType)
|
|
62
|
-
{
|
|
63
|
-
MatrixOptions opts;
|
|
64
|
-
opts.Fatal() = fatal;
|
|
65
|
-
opts.NoTranspose() = !transpose;
|
|
66
|
-
opts.Format() = inputSaveType;
|
|
67
|
-
|
|
68
|
-
return Save(filename, matrix, opts);
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
template<typename MatType, typename DataOptionsType>
|
|
72
|
-
bool Save(const std::string& filename,
|
|
73
|
-
const MatType& matrix,
|
|
74
|
-
const DataOptionsType& opts,
|
|
75
|
-
std::enable_if_t<IsArma<MatType>::value ||
|
|
76
|
-
IsSparseMat<MatType>::value>*)
|
|
77
|
-
{
|
|
78
|
-
//! just use default copy ctor with = operator and make a copy.
|
|
79
|
-
DataOptionsType copyOpts(opts);
|
|
80
|
-
return Save(filename, matrix, copyOpts);
|
|
81
|
-
}
|
|
82
|
-
|
|
83
|
-
/*
|
|
84
|
-
* Add this SFINAE in here because the compiler is so stupid that it is not
|
|
85
|
-
* able to distinguish between these two:
|
|
86
|
-
*
|
|
87
|
-
* data::Save(filename, "model", *output);
|
|
88
|
-
*
|
|
89
|
-
* and
|
|
90
|
-
*
|
|
91
|
-
* data::Save(filename, matrix, opts);
|
|
92
|
-
*
|
|
93
|
-
* The second SFINAE is added because the compiler is bot able to see the
|
|
94
|
-
* difference between:
|
|
95
|
-
*
|
|
96
|
-
* data::Save(filename, Row/Col, fatal);
|
|
97
|
-
*
|
|
98
|
-
* and
|
|
99
|
-
*
|
|
100
|
-
* data::Save(filename, Row/Col, Opts);
|
|
101
|
-
*
|
|
102
|
-
* This SFINAE is temporary and must be removed after the integration of stage 3 or
|
|
103
|
-
* when the compiler becomes more intelligent.
|
|
104
|
-
*/
|
|
105
|
-
template<typename MatType, typename DataOptionsType>
|
|
106
|
-
bool Save(const std::string& filename,
|
|
107
|
-
const MatType& matrix,
|
|
108
|
-
DataOptionsType& opts,
|
|
109
|
-
std::enable_if_t<IsArma<MatType>::value ||
|
|
110
|
-
IsSparseMat<MatType>::value>*,
|
|
111
|
-
std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
|
|
112
|
-
{
|
|
113
|
-
Timer::Start("saving_data");
|
|
114
|
-
|
|
115
|
-
bool success = DetectFileType<MatType>(filename, opts, false);
|
|
116
|
-
if (!success)
|
|
117
|
-
{
|
|
118
|
-
Timer::Stop("saving_data");
|
|
119
|
-
return false;
|
|
120
|
-
}
|
|
121
|
-
|
|
122
|
-
std::fstream stream;
|
|
123
|
-
success = OpenFile(filename, opts, false, stream);
|
|
124
|
-
if (!success)
|
|
125
|
-
{
|
|
126
|
-
Timer::Stop("saving_data");
|
|
127
|
-
return false;
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
// Try to save the file.
|
|
131
|
-
Log::Info << "Saving " << opts.FileTypeToString() << " to '" << filename
|
|
132
|
-
<< "'." << std::endl;
|
|
133
|
-
if constexpr (IsArma<MatType>::value || IsSparseMat<MatType>::value)
|
|
134
|
-
{
|
|
135
|
-
TextOptions txtOpts(std::move(opts));
|
|
136
|
-
if constexpr (IsSparseMat<MatType>::value)
|
|
137
|
-
{
|
|
138
|
-
success = SaveSparse(matrix, txtOpts, filename, stream);
|
|
139
|
-
}
|
|
140
|
-
else if constexpr (IsCol<MatType>::value)
|
|
141
|
-
{
|
|
142
|
-
opts.NoTranspose() = true;
|
|
143
|
-
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
144
|
-
}
|
|
145
|
-
else if constexpr (IsRow<MatType>::value)
|
|
146
|
-
{
|
|
147
|
-
opts.NoTranspose() = false;
|
|
148
|
-
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
149
|
-
}
|
|
150
|
-
else if constexpr (IsDense<MatType>::value)
|
|
151
|
-
{
|
|
152
|
-
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
153
|
-
}
|
|
154
|
-
opts = std::move(txtOpts);
|
|
155
|
-
}
|
|
156
|
-
else
|
|
157
|
-
{
|
|
158
|
-
if (opts.Fatal())
|
|
159
|
-
Log::Fatal << "DataOptionsType is unknown! Please use a known type or "
|
|
160
|
-
<< "or provide specific overloads." << std::endl;
|
|
161
|
-
else
|
|
162
|
-
Log::Warn << "DataOptionsType is unknown! Please use a known type or "
|
|
163
|
-
<< "or provide specific overloads." << std::endl;
|
|
164
|
-
|
|
165
|
-
return false;
|
|
166
|
-
}
|
|
167
|
-
|
|
168
|
-
if (!success)
|
|
169
|
-
{
|
|
170
|
-
Timer::Stop("saving_data");
|
|
171
|
-
if (opts.Fatal())
|
|
172
|
-
Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
|
|
173
|
-
else
|
|
174
|
-
Log::Warn << "Save to '" << filename << "' failed." << std::endl;
|
|
175
|
-
return false;
|
|
176
|
-
}
|
|
177
|
-
|
|
178
|
-
Timer::Stop("saving_data");
|
|
179
|
-
|
|
180
|
-
return success;
|
|
181
|
-
}
|
|
182
|
-
|
|
183
|
-
template<typename eT>
|
|
184
|
-
bool SaveDense(const arma::Mat<eT>& matrix,
|
|
185
|
-
TextOptions& opts,
|
|
186
|
-
const std::string& filename,
|
|
187
|
-
std::fstream& stream)
|
|
188
|
-
{
|
|
189
|
-
bool success = false;
|
|
190
|
-
arma::Mat<eT> tmp;
|
|
191
|
-
// Transpose the matrix.
|
|
192
|
-
if (!opts.NoTranspose())
|
|
193
|
-
{
|
|
194
|
-
tmp = trans(matrix);
|
|
195
|
-
success = SaveMatrix(tmp, opts, filename, stream);
|
|
196
|
-
}
|
|
197
|
-
else
|
|
198
|
-
success = SaveMatrix(matrix, opts, filename, stream);
|
|
199
|
-
|
|
200
|
-
return success;
|
|
201
|
-
}
|
|
202
|
-
|
|
203
|
-
// Save a Sparse Matrix
|
|
204
|
-
template<typename eT>
|
|
205
|
-
bool SaveSparse(const arma::SpMat<eT>& matrix,
|
|
206
|
-
TextOptions& opts,
|
|
207
|
-
const std::string& filename,
|
|
208
|
-
std::fstream& stream)
|
|
209
|
-
{
|
|
210
|
-
bool success = false;
|
|
211
|
-
arma::SpMat<eT> tmp;
|
|
212
|
-
|
|
213
|
-
// Transpose the matrix.
|
|
214
|
-
if (!opts.NoTranspose())
|
|
215
|
-
{
|
|
216
|
-
arma::SpMat<eT> tmp = trans(matrix);
|
|
217
|
-
success = SaveMatrix(tmp, opts, filename, stream);
|
|
218
|
-
}
|
|
219
|
-
else
|
|
220
|
-
success = SaveMatrix(matrix, opts, filename, stream);
|
|
221
|
-
|
|
222
|
-
return success;
|
|
223
|
-
}
|
|
224
|
-
|
|
225
|
-
//! Save a model to file.
|
|
226
|
-
template<typename T>
|
|
227
|
-
bool Save(const std::string& filename,
|
|
228
|
-
const std::string& name,
|
|
229
|
-
T& t,
|
|
230
|
-
const bool fatal,
|
|
231
|
-
format f,
|
|
232
|
-
std::enable_if_t<HasSerialize<T>::value>*)
|
|
233
|
-
{
|
|
234
|
-
if (f == format::autodetect)
|
|
235
|
-
{
|
|
236
|
-
std::string extension = Extension(filename);
|
|
237
|
-
|
|
238
|
-
if (extension == "xml")
|
|
239
|
-
f = format::xml;
|
|
240
|
-
else if (extension == "bin")
|
|
241
|
-
f = format::binary;
|
|
242
|
-
else if (extension == "json")
|
|
243
|
-
f = format::json;
|
|
244
|
-
else
|
|
245
|
-
{
|
|
246
|
-
if (fatal)
|
|
247
|
-
Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
|
|
248
|
-
<< " extension? (allowed: xml/bin/json)" << std::endl;
|
|
249
|
-
else
|
|
250
|
-
Log::Warn << "Unable to detect type of '" << filename << "'; save "
|
|
251
|
-
<< "failed. Incorrect extension? (allowed: xml/bin/json)"
|
|
252
|
-
<< std::endl;
|
|
253
|
-
|
|
254
|
-
return false;
|
|
255
|
-
}
|
|
256
|
-
}
|
|
257
|
-
|
|
258
|
-
// Open the file to save to.
|
|
259
|
-
std::ofstream ofs;
|
|
260
|
-
#ifdef _WIN32
|
|
261
|
-
if (f == format::binary) // Open non-text types in binary mode on Windows.
|
|
262
|
-
ofs.open(filename, std::ofstream::out | std::ofstream::binary);
|
|
263
|
-
else
|
|
264
|
-
ofs.open(filename, std::ofstream::out);
|
|
265
|
-
#else
|
|
266
|
-
ofs.open(filename, std::ofstream::out);
|
|
267
|
-
#endif
|
|
268
|
-
|
|
269
|
-
if (!ofs.is_open())
|
|
270
|
-
{
|
|
271
|
-
if (fatal)
|
|
272
|
-
Log::Fatal << "Unable to open file '" << filename << "' to save object '"
|
|
273
|
-
<< name << "'." << std::endl;
|
|
274
|
-
else
|
|
275
|
-
Log::Warn << "Unable to open file '" << filename << "' to save object '"
|
|
276
|
-
<< name << "'." << std::endl;
|
|
277
|
-
|
|
278
|
-
return false;
|
|
279
|
-
}
|
|
280
|
-
|
|
281
|
-
try
|
|
282
|
-
{
|
|
283
|
-
if (f == format::xml)
|
|
284
|
-
{
|
|
285
|
-
cereal::XMLOutputArchive ar(ofs);
|
|
286
|
-
ar(cereal::make_nvp(name.c_str(), t));
|
|
287
|
-
}
|
|
288
|
-
else if (f == format::json)
|
|
289
|
-
{
|
|
290
|
-
cereal::JSONOutputArchive ar(ofs);
|
|
291
|
-
ar(cereal::make_nvp(name.c_str(), t));
|
|
292
|
-
}
|
|
293
|
-
else if (f == format::binary)
|
|
294
|
-
{
|
|
295
|
-
cereal::BinaryOutputArchive ar(ofs);
|
|
296
|
-
ar(cereal::make_nvp(name.c_str(), t));
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
return true;
|
|
300
|
-
}
|
|
301
|
-
catch (cereal::Exception& e)
|
|
302
|
-
{
|
|
303
|
-
if (fatal)
|
|
304
|
-
Log::Fatal << e.what() << std::endl;
|
|
305
|
-
else
|
|
306
|
-
Log::Warn << e.what() << std::endl;
|
|
307
|
-
|
|
308
|
-
return false;
|
|
309
|
-
}
|
|
310
|
-
}
|
|
311
|
-
|
|
312
|
-
} // namespace data
|
|
313
|
-
} // namespace mlpack
|
|
314
|
-
|
|
315
|
-
#endif
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/save_impl.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Implementation of save functionality.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_SAVE_IMPL_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_SAVE_IMPL_HPP
|
|
15
|
+
|
|
16
|
+
// In case it hasn't already been included.
|
|
17
|
+
#include "save.hpp"
|
|
18
|
+
#include "extension.hpp"
|
|
19
|
+
|
|
20
|
+
namespace mlpack {
|
|
21
|
+
namespace data {
|
|
22
|
+
|
|
23
|
+
template<typename eT>
|
|
24
|
+
bool Save(const std::string& filename,
|
|
25
|
+
const arma::Col<eT>& vec,
|
|
26
|
+
const bool fatal,
|
|
27
|
+
FileType inputSaveType)
|
|
28
|
+
{
|
|
29
|
+
// Don't transpose: one observation per line (for CSVs at least).
|
|
30
|
+
return Save(filename, vec, fatal, false, inputSaveType);
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
template<typename eT>
|
|
34
|
+
bool Save(const std::string& filename,
|
|
35
|
+
const arma::Row<eT>& rowvec,
|
|
36
|
+
const bool fatal,
|
|
37
|
+
FileType inputSaveType)
|
|
38
|
+
{
|
|
39
|
+
return Save(filename, rowvec, fatal, true, inputSaveType);
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
// Save a Sparse Matrix
|
|
43
|
+
template<typename eT>
|
|
44
|
+
bool Save(const std::string& filename,
|
|
45
|
+
const arma::SpMat<eT>& matrix,
|
|
46
|
+
const bool fatal,
|
|
47
|
+
bool transpose)
|
|
48
|
+
{
|
|
49
|
+
MatrixOptions opts;
|
|
50
|
+
opts.Fatal() = fatal;
|
|
51
|
+
opts.NoTranspose() = !transpose;
|
|
52
|
+
|
|
53
|
+
return Save(filename, matrix, opts);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
template<typename eT>
|
|
57
|
+
bool Save(const std::string& filename,
|
|
58
|
+
const arma::Mat<eT>& matrix,
|
|
59
|
+
const bool fatal,
|
|
60
|
+
bool transpose,
|
|
61
|
+
FileType inputSaveType)
|
|
62
|
+
{
|
|
63
|
+
MatrixOptions opts;
|
|
64
|
+
opts.Fatal() = fatal;
|
|
65
|
+
opts.NoTranspose() = !transpose;
|
|
66
|
+
opts.Format() = inputSaveType;
|
|
67
|
+
|
|
68
|
+
return Save(filename, matrix, opts);
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
template<typename MatType, typename DataOptionsType>
|
|
72
|
+
bool Save(const std::string& filename,
|
|
73
|
+
const MatType& matrix,
|
|
74
|
+
const DataOptionsType& opts,
|
|
75
|
+
std::enable_if_t<IsArma<MatType>::value ||
|
|
76
|
+
IsSparseMat<MatType>::value>*)
|
|
77
|
+
{
|
|
78
|
+
//! just use default copy ctor with = operator and make a copy.
|
|
79
|
+
DataOptionsType copyOpts(opts);
|
|
80
|
+
return Save(filename, matrix, copyOpts);
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/*
|
|
84
|
+
* Add this SFINAE in here because the compiler is so stupid that it is not
|
|
85
|
+
* able to distinguish between these two:
|
|
86
|
+
*
|
|
87
|
+
* data::Save(filename, "model", *output);
|
|
88
|
+
*
|
|
89
|
+
* and
|
|
90
|
+
*
|
|
91
|
+
* data::Save(filename, matrix, opts);
|
|
92
|
+
*
|
|
93
|
+
* The second SFINAE is added because the compiler is bot able to see the
|
|
94
|
+
* difference between:
|
|
95
|
+
*
|
|
96
|
+
* data::Save(filename, Row/Col, fatal);
|
|
97
|
+
*
|
|
98
|
+
* and
|
|
99
|
+
*
|
|
100
|
+
* data::Save(filename, Row/Col, Opts);
|
|
101
|
+
*
|
|
102
|
+
* This SFINAE is temporary and must be removed after the integration of stage 3 or
|
|
103
|
+
* when the compiler becomes more intelligent.
|
|
104
|
+
*/
|
|
105
|
+
template<typename MatType, typename DataOptionsType>
|
|
106
|
+
bool Save(const std::string& filename,
|
|
107
|
+
const MatType& matrix,
|
|
108
|
+
DataOptionsType& opts,
|
|
109
|
+
std::enable_if_t<IsArma<MatType>::value ||
|
|
110
|
+
IsSparseMat<MatType>::value>*,
|
|
111
|
+
std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
|
|
112
|
+
{
|
|
113
|
+
Timer::Start("saving_data");
|
|
114
|
+
|
|
115
|
+
bool success = DetectFileType<MatType>(filename, opts, false);
|
|
116
|
+
if (!success)
|
|
117
|
+
{
|
|
118
|
+
Timer::Stop("saving_data");
|
|
119
|
+
return false;
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
std::fstream stream;
|
|
123
|
+
success = OpenFile(filename, opts, false, stream);
|
|
124
|
+
if (!success)
|
|
125
|
+
{
|
|
126
|
+
Timer::Stop("saving_data");
|
|
127
|
+
return false;
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
// Try to save the file.
|
|
131
|
+
Log::Info << "Saving " << opts.FileTypeToString() << " to '" << filename
|
|
132
|
+
<< "'." << std::endl;
|
|
133
|
+
if constexpr (IsArma<MatType>::value || IsSparseMat<MatType>::value)
|
|
134
|
+
{
|
|
135
|
+
TextOptions txtOpts(std::move(opts));
|
|
136
|
+
if constexpr (IsSparseMat<MatType>::value)
|
|
137
|
+
{
|
|
138
|
+
success = SaveSparse(matrix, txtOpts, filename, stream);
|
|
139
|
+
}
|
|
140
|
+
else if constexpr (IsCol<MatType>::value)
|
|
141
|
+
{
|
|
142
|
+
opts.NoTranspose() = true;
|
|
143
|
+
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
144
|
+
}
|
|
145
|
+
else if constexpr (IsRow<MatType>::value)
|
|
146
|
+
{
|
|
147
|
+
opts.NoTranspose() = false;
|
|
148
|
+
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
149
|
+
}
|
|
150
|
+
else if constexpr (IsDense<MatType>::value)
|
|
151
|
+
{
|
|
152
|
+
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
153
|
+
}
|
|
154
|
+
opts = std::move(txtOpts);
|
|
155
|
+
}
|
|
156
|
+
else
|
|
157
|
+
{
|
|
158
|
+
if (opts.Fatal())
|
|
159
|
+
Log::Fatal << "DataOptionsType is unknown! Please use a known type or "
|
|
160
|
+
<< "or provide specific overloads." << std::endl;
|
|
161
|
+
else
|
|
162
|
+
Log::Warn << "DataOptionsType is unknown! Please use a known type or "
|
|
163
|
+
<< "or provide specific overloads." << std::endl;
|
|
164
|
+
|
|
165
|
+
return false;
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
if (!success)
|
|
169
|
+
{
|
|
170
|
+
Timer::Stop("saving_data");
|
|
171
|
+
if (opts.Fatal())
|
|
172
|
+
Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
|
|
173
|
+
else
|
|
174
|
+
Log::Warn << "Save to '" << filename << "' failed." << std::endl;
|
|
175
|
+
return false;
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
Timer::Stop("saving_data");
|
|
179
|
+
|
|
180
|
+
return success;
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
template<typename eT>
|
|
184
|
+
bool SaveDense(const arma::Mat<eT>& matrix,
|
|
185
|
+
TextOptions& opts,
|
|
186
|
+
const std::string& filename,
|
|
187
|
+
std::fstream& stream)
|
|
188
|
+
{
|
|
189
|
+
bool success = false;
|
|
190
|
+
arma::Mat<eT> tmp;
|
|
191
|
+
// Transpose the matrix.
|
|
192
|
+
if (!opts.NoTranspose())
|
|
193
|
+
{
|
|
194
|
+
tmp = trans(matrix);
|
|
195
|
+
success = SaveMatrix(tmp, opts, filename, stream);
|
|
196
|
+
}
|
|
197
|
+
else
|
|
198
|
+
success = SaveMatrix(matrix, opts, filename, stream);
|
|
199
|
+
|
|
200
|
+
return success;
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
// Save a Sparse Matrix
|
|
204
|
+
template<typename eT>
|
|
205
|
+
bool SaveSparse(const arma::SpMat<eT>& matrix,
|
|
206
|
+
TextOptions& opts,
|
|
207
|
+
const std::string& filename,
|
|
208
|
+
std::fstream& stream)
|
|
209
|
+
{
|
|
210
|
+
bool success = false;
|
|
211
|
+
arma::SpMat<eT> tmp;
|
|
212
|
+
|
|
213
|
+
// Transpose the matrix.
|
|
214
|
+
if (!opts.NoTranspose())
|
|
215
|
+
{
|
|
216
|
+
arma::SpMat<eT> tmp = trans(matrix);
|
|
217
|
+
success = SaveMatrix(tmp, opts, filename, stream);
|
|
218
|
+
}
|
|
219
|
+
else
|
|
220
|
+
success = SaveMatrix(matrix, opts, filename, stream);
|
|
221
|
+
|
|
222
|
+
return success;
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
//! Save a model to file.
|
|
226
|
+
template<typename T>
|
|
227
|
+
bool Save(const std::string& filename,
|
|
228
|
+
const std::string& name,
|
|
229
|
+
T& t,
|
|
230
|
+
const bool fatal,
|
|
231
|
+
format f,
|
|
232
|
+
std::enable_if_t<HasSerialize<T>::value>*)
|
|
233
|
+
{
|
|
234
|
+
if (f == format::autodetect)
|
|
235
|
+
{
|
|
236
|
+
std::string extension = Extension(filename);
|
|
237
|
+
|
|
238
|
+
if (extension == "xml")
|
|
239
|
+
f = format::xml;
|
|
240
|
+
else if (extension == "bin")
|
|
241
|
+
f = format::binary;
|
|
242
|
+
else if (extension == "json")
|
|
243
|
+
f = format::json;
|
|
244
|
+
else
|
|
245
|
+
{
|
|
246
|
+
if (fatal)
|
|
247
|
+
Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
|
|
248
|
+
<< " extension? (allowed: xml/bin/json)" << std::endl;
|
|
249
|
+
else
|
|
250
|
+
Log::Warn << "Unable to detect type of '" << filename << "'; save "
|
|
251
|
+
<< "failed. Incorrect extension? (allowed: xml/bin/json)"
|
|
252
|
+
<< std::endl;
|
|
253
|
+
|
|
254
|
+
return false;
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
// Open the file to save to.
|
|
259
|
+
std::ofstream ofs;
|
|
260
|
+
#ifdef _WIN32
|
|
261
|
+
if (f == format::binary) // Open non-text types in binary mode on Windows.
|
|
262
|
+
ofs.open(filename, std::ofstream::out | std::ofstream::binary);
|
|
263
|
+
else
|
|
264
|
+
ofs.open(filename, std::ofstream::out);
|
|
265
|
+
#else
|
|
266
|
+
ofs.open(filename, std::ofstream::out);
|
|
267
|
+
#endif
|
|
268
|
+
|
|
269
|
+
if (!ofs.is_open())
|
|
270
|
+
{
|
|
271
|
+
if (fatal)
|
|
272
|
+
Log::Fatal << "Unable to open file '" << filename << "' to save object '"
|
|
273
|
+
<< name << "'." << std::endl;
|
|
274
|
+
else
|
|
275
|
+
Log::Warn << "Unable to open file '" << filename << "' to save object '"
|
|
276
|
+
<< name << "'." << std::endl;
|
|
277
|
+
|
|
278
|
+
return false;
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
try
|
|
282
|
+
{
|
|
283
|
+
if (f == format::xml)
|
|
284
|
+
{
|
|
285
|
+
cereal::XMLOutputArchive ar(ofs);
|
|
286
|
+
ar(cereal::make_nvp(name.c_str(), t));
|
|
287
|
+
}
|
|
288
|
+
else if (f == format::json)
|
|
289
|
+
{
|
|
290
|
+
cereal::JSONOutputArchive ar(ofs);
|
|
291
|
+
ar(cereal::make_nvp(name.c_str(), t));
|
|
292
|
+
}
|
|
293
|
+
else if (f == format::binary)
|
|
294
|
+
{
|
|
295
|
+
cereal::BinaryOutputArchive ar(ofs);
|
|
296
|
+
ar(cereal::make_nvp(name.c_str(), t));
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
return true;
|
|
300
|
+
}
|
|
301
|
+
catch (cereal::Exception& e)
|
|
302
|
+
{
|
|
303
|
+
if (fatal)
|
|
304
|
+
Log::Fatal << e.what() << std::endl;
|
|
305
|
+
else
|
|
306
|
+
Log::Warn << e.what() << std::endl;
|
|
307
|
+
|
|
308
|
+
return false;
|
|
309
|
+
}
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
} // namespace data
|
|
313
|
+
} // namespace mlpack
|
|
314
|
+
|
|
315
|
+
#endif
|