mlpack 4.6.1__cp38-cp38-win_amd64.whl → 4.6.2__cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. mlpack/__init__.py +3 -3
  2. mlpack/adaboost_classify.cp38-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp38-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp38-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp38-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp38-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp38-win_amd64.pyd +0 -0
  8. mlpack/cf.cp38-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp38-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp38-win_amd64.pyd +0 -0
  11. mlpack/det.cp38-win_amd64.pyd +0 -0
  12. mlpack/emst.cp38-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp38-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp38-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp38-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp38-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp38-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp38-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp38-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp38-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp38-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp38-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +21 -12
  24. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +49 -39
  25. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +9 -46
  26. mlpack/include/mlpack/core/data/save_impl.hpp +315 -315
  27. mlpack/include/mlpack/core/data/utilities.hpp +158 -158
  28. mlpack/include/mlpack/core/math/ccov.hpp +1 -0
  29. mlpack/include/mlpack/core/math/ccov_impl.hpp +4 -5
  30. mlpack/include/mlpack/core/math/make_alias.hpp +98 -3
  31. mlpack/include/mlpack/core/util/arma_traits.hpp +19 -2
  32. mlpack/include/mlpack/core/util/gitversion.hpp +1 -1
  33. mlpack/include/mlpack/core/util/sfinae_utility.hpp +24 -2
  34. mlpack/include/mlpack/core/util/version.hpp +1 -1
  35. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -2
  36. mlpack/include/mlpack/methods/ann/init_rules/network_init.hpp +5 -5
  37. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +3 -2
  38. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +19 -20
  39. mlpack/include/mlpack/methods/ann/layer/concat.hpp +1 -0
  40. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +6 -7
  41. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +3 -3
  42. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +3 -3
  43. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +1 -0
  44. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +11 -14
  45. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +5 -4
  46. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +15 -14
  47. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +3 -2
  48. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +14 -15
  49. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +6 -5
  50. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +24 -25
  51. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +1 -0
  52. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +4 -4
  53. mlpack/include/mlpack/methods/ann/layer/padding.hpp +1 -0
  54. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +12 -13
  55. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +3 -2
  56. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +5 -4
  57. mlpack/include/mlpack/methods/ann/rnn.hpp +19 -18
  58. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +15 -15
  59. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_impl.hpp +3 -8
  60. mlpack/include/mlpack/methods/decision_tree/fitness_functions/gini_gain.hpp +5 -8
  61. mlpack/include/mlpack/methods/decision_tree/fitness_functions/information_gain.hpp +5 -8
  62. mlpack/include/mlpack/methods/gmm/diagonal_gmm_impl.hpp +2 -1
  63. mlpack/include/mlpack/methods/gmm/eigenvalue_ratio_constraint.hpp +3 -3
  64. mlpack/include/mlpack/methods/gmm/gmm_impl.hpp +2 -1
  65. mlpack/include/mlpack/methods/hmm/hmm_impl.hpp +10 -5
  66. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +57 -37
  67. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +69 -59
  68. mlpack/kde.cp38-win_amd64.pyd +0 -0
  69. mlpack/kernel_pca.cp38-win_amd64.pyd +0 -0
  70. mlpack/kfn.cp38-win_amd64.pyd +0 -0
  71. mlpack/kmeans.cp38-win_amd64.pyd +0 -0
  72. mlpack/knn.cp38-win_amd64.pyd +0 -0
  73. mlpack/krann.cp38-win_amd64.pyd +0 -0
  74. mlpack/lars.cp38-win_amd64.pyd +0 -0
  75. mlpack/linear_regression_predict.cp38-win_amd64.pyd +0 -0
  76. mlpack/linear_regression_train.cp38-win_amd64.pyd +0 -0
  77. mlpack/linear_svm.cp38-win_amd64.pyd +0 -0
  78. mlpack/lmnn.cp38-win_amd64.pyd +0 -0
  79. mlpack/local_coordinate_coding.cp38-win_amd64.pyd +0 -0
  80. mlpack/logistic_regression.cp38-win_amd64.pyd +0 -0
  81. mlpack/lsh.cp38-win_amd64.pyd +0 -0
  82. mlpack/mean_shift.cp38-win_amd64.pyd +0 -0
  83. mlpack/nbc.cp38-win_amd64.pyd +0 -0
  84. mlpack/nca.cp38-win_amd64.pyd +0 -0
  85. mlpack/nmf.cp38-win_amd64.pyd +0 -0
  86. mlpack/pca.cp38-win_amd64.pyd +0 -0
  87. mlpack/perceptron.cp38-win_amd64.pyd +0 -0
  88. mlpack/preprocess_binarize.cp38-win_amd64.pyd +0 -0
  89. mlpack/preprocess_describe.cp38-win_amd64.pyd +0 -0
  90. mlpack/preprocess_one_hot_encoding.cp38-win_amd64.pyd +0 -0
  91. mlpack/preprocess_scale.cp38-win_amd64.pyd +0 -0
  92. mlpack/preprocess_split.cp38-win_amd64.pyd +0 -0
  93. mlpack/radical.cp38-win_amd64.pyd +0 -0
  94. mlpack/random_forest.cp38-win_amd64.pyd +0 -0
  95. mlpack/softmax_regression.cp38-win_amd64.pyd +0 -0
  96. mlpack/sparse_coding.cp38-win_amd64.pyd +0 -0
  97. mlpack-4.6.2.dist-info/DELVEWHEEL +2 -0
  98. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/METADATA +2 -2
  99. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/RECORD +102 -102
  100. mlpack-4.6.1.dist-info/DELVEWHEEL +0 -2
  101. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/WHEEL +0 -0
  102. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/top_level.txt +0 -0
  103. mlpack.libs/{.load-order-mlpack-4.6.1 → .load-order-mlpack-4.6.2} +1 -1
@@ -1,315 +1,315 @@
1
- /**
2
- * @file core/data/save_impl.hpp
3
- * @author Ryan Curtin
4
- * @author Omar Shrit
5
- *
6
- * Implementation of save functionality.
7
- *
8
- * mlpack is free software; you may redistribute it and/or modify it under the
9
- * terms of the 3-clause BSD license. You should have received a copy of the
10
- * 3-clause BSD license along with mlpack. If not, see
11
- * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
- */
13
- #ifndef MLPACK_CORE_DATA_SAVE_IMPL_HPP
14
- #define MLPACK_CORE_DATA_SAVE_IMPL_HPP
15
-
16
- // In case it hasn't already been included.
17
- #include "save.hpp"
18
- #include "extension.hpp"
19
-
20
- namespace mlpack {
21
- namespace data {
22
-
23
- template<typename eT>
24
- bool Save(const std::string& filename,
25
- const arma::Col<eT>& vec,
26
- const bool fatal,
27
- FileType inputSaveType)
28
- {
29
- // Don't transpose: one observation per line (for CSVs at least).
30
- return Save(filename, vec, fatal, false, inputSaveType);
31
- }
32
-
33
- template<typename eT>
34
- bool Save(const std::string& filename,
35
- const arma::Row<eT>& rowvec,
36
- const bool fatal,
37
- FileType inputSaveType)
38
- {
39
- return Save(filename, rowvec, fatal, true, inputSaveType);
40
- }
41
-
42
- // Save a Sparse Matrix
43
- template<typename eT>
44
- bool Save(const std::string& filename,
45
- const arma::SpMat<eT>& matrix,
46
- const bool fatal,
47
- bool transpose)
48
- {
49
- MatrixOptions opts;
50
- opts.Fatal() = fatal;
51
- opts.NoTranspose() = !transpose;
52
-
53
- return Save(filename, matrix, opts);
54
- }
55
-
56
- template<typename eT>
57
- bool Save(const std::string& filename,
58
- const arma::Mat<eT>& matrix,
59
- const bool fatal,
60
- bool transpose,
61
- FileType inputSaveType)
62
- {
63
- MatrixOptions opts;
64
- opts.Fatal() = fatal;
65
- opts.NoTranspose() = !transpose;
66
- opts.Format() = inputSaveType;
67
-
68
- return Save(filename, matrix, opts);
69
- }
70
-
71
- template<typename MatType, typename DataOptionsType>
72
- bool Save(const std::string& filename,
73
- const MatType& matrix,
74
- const DataOptionsType& opts,
75
- std::enable_if_t<IsArma<MatType>::value ||
76
- IsSparseMat<MatType>::value>*)
77
- {
78
- //! just use default copy ctor with = operator and make a copy.
79
- DataOptionsType copyOpts(opts);
80
- return Save(filename, matrix, copyOpts);
81
- }
82
-
83
- /*
84
- * Add this SFINAE in here because the compiler is so stupid that it is not
85
- * able to distinguish between these two:
86
- *
87
- * data::Save(filename, "model", *output);
88
- *
89
- * and
90
- *
91
- * data::Save(filename, matrix, opts);
92
- *
93
- * The second SFINAE is added because the compiler is bot able to see the
94
- * difference between:
95
- *
96
- * data::Save(filename, Row/Col, fatal);
97
- *
98
- * and
99
- *
100
- * data::Save(filename, Row/Col, Opts);
101
- *
102
- * This SFINAE is temporary and must be removed after the integration of stage 3 or
103
- * when the compiler becomes more intelligent.
104
- */
105
- template<typename MatType, typename DataOptionsType>
106
- bool Save(const std::string& filename,
107
- const MatType& matrix,
108
- DataOptionsType& opts,
109
- std::enable_if_t<IsArma<MatType>::value ||
110
- IsSparseMat<MatType>::value>*,
111
- std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
112
- {
113
- Timer::Start("saving_data");
114
-
115
- bool success = DetectFileType<MatType>(filename, opts, false);
116
- if (!success)
117
- {
118
- Timer::Stop("saving_data");
119
- return false;
120
- }
121
-
122
- std::fstream stream;
123
- success = OpenFile(filename, opts, false, stream);
124
- if (!success)
125
- {
126
- Timer::Stop("saving_data");
127
- return false;
128
- }
129
-
130
- // Try to save the file.
131
- Log::Info << "Saving " << opts.FileTypeToString() << " to '" << filename
132
- << "'." << std::endl;
133
- if constexpr (IsArma<MatType>::value || IsSparseMat<MatType>::value)
134
- {
135
- TextOptions txtOpts(std::move(opts));
136
- if constexpr (IsSparseMat<MatType>::value)
137
- {
138
- success = SaveSparse(matrix, txtOpts, filename, stream);
139
- }
140
- else if constexpr (IsCol<MatType>::value)
141
- {
142
- opts.NoTranspose() = true;
143
- success = SaveDense(matrix, txtOpts, filename, stream);
144
- }
145
- else if constexpr (IsRow<MatType>::value)
146
- {
147
- opts.NoTranspose() = false;
148
- success = SaveDense(matrix, txtOpts, filename, stream);
149
- }
150
- else if constexpr (IsDense<MatType>::value)
151
- {
152
- success = SaveDense(matrix, txtOpts, filename, stream);
153
- }
154
- opts = std::move(txtOpts);
155
- }
156
- else
157
- {
158
- if (opts.Fatal())
159
- Log::Fatal << "DataOptionsType is unknown! Please use a known type or "
160
- << "or provide specific overloads." << std::endl;
161
- else
162
- Log::Warn << "DataOptionsType is unknown! Please use a known type or "
163
- << "or provide specific overloads." << std::endl;
164
-
165
- return false;
166
- }
167
-
168
- if (!success)
169
- {
170
- Timer::Stop("saving_data");
171
- if (opts.Fatal())
172
- Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
173
- else
174
- Log::Warn << "Save to '" << filename << "' failed." << std::endl;
175
- return false;
176
- }
177
-
178
- Timer::Stop("saving_data");
179
-
180
- return success;
181
- }
182
-
183
- template<typename eT>
184
- bool SaveDense(const arma::Mat<eT>& matrix,
185
- TextOptions& opts,
186
- const std::string& filename,
187
- std::fstream& stream)
188
- {
189
- bool success = false;
190
- arma::Mat<eT> tmp;
191
- // Transpose the matrix.
192
- if (!opts.NoTranspose())
193
- {
194
- tmp = trans(matrix);
195
- success = SaveMatrix(tmp, opts, filename, stream);
196
- }
197
- else
198
- success = SaveMatrix(matrix, opts, filename, stream);
199
-
200
- return success;
201
- }
202
-
203
- // Save a Sparse Matrix
204
- template<typename eT>
205
- bool SaveSparse(const arma::SpMat<eT>& matrix,
206
- TextOptions& opts,
207
- const std::string& filename,
208
- std::fstream& stream)
209
- {
210
- bool success = false;
211
- arma::SpMat<eT> tmp;
212
-
213
- // Transpose the matrix.
214
- if (!opts.NoTranspose())
215
- {
216
- arma::SpMat<eT> tmp = trans(matrix);
217
- success = SaveMatrix(tmp, opts, filename, stream);
218
- }
219
- else
220
- success = SaveMatrix(matrix, opts, filename, stream);
221
-
222
- return success;
223
- }
224
-
225
- //! Save a model to file.
226
- template<typename T>
227
- bool Save(const std::string& filename,
228
- const std::string& name,
229
- T& t,
230
- const bool fatal,
231
- format f,
232
- std::enable_if_t<HasSerialize<T>::value>*)
233
- {
234
- if (f == format::autodetect)
235
- {
236
- std::string extension = Extension(filename);
237
-
238
- if (extension == "xml")
239
- f = format::xml;
240
- else if (extension == "bin")
241
- f = format::binary;
242
- else if (extension == "json")
243
- f = format::json;
244
- else
245
- {
246
- if (fatal)
247
- Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
248
- << " extension? (allowed: xml/bin/json)" << std::endl;
249
- else
250
- Log::Warn << "Unable to detect type of '" << filename << "'; save "
251
- << "failed. Incorrect extension? (allowed: xml/bin/json)"
252
- << std::endl;
253
-
254
- return false;
255
- }
256
- }
257
-
258
- // Open the file to save to.
259
- std::ofstream ofs;
260
- #ifdef _WIN32
261
- if (f == format::binary) // Open non-text types in binary mode on Windows.
262
- ofs.open(filename, std::ofstream::out | std::ofstream::binary);
263
- else
264
- ofs.open(filename, std::ofstream::out);
265
- #else
266
- ofs.open(filename, std::ofstream::out);
267
- #endif
268
-
269
- if (!ofs.is_open())
270
- {
271
- if (fatal)
272
- Log::Fatal << "Unable to open file '" << filename << "' to save object '"
273
- << name << "'." << std::endl;
274
- else
275
- Log::Warn << "Unable to open file '" << filename << "' to save object '"
276
- << name << "'." << std::endl;
277
-
278
- return false;
279
- }
280
-
281
- try
282
- {
283
- if (f == format::xml)
284
- {
285
- cereal::XMLOutputArchive ar(ofs);
286
- ar(cereal::make_nvp(name.c_str(), t));
287
- }
288
- else if (f == format::json)
289
- {
290
- cereal::JSONOutputArchive ar(ofs);
291
- ar(cereal::make_nvp(name.c_str(), t));
292
- }
293
- else if (f == format::binary)
294
- {
295
- cereal::BinaryOutputArchive ar(ofs);
296
- ar(cereal::make_nvp(name.c_str(), t));
297
- }
298
-
299
- return true;
300
- }
301
- catch (cereal::Exception& e)
302
- {
303
- if (fatal)
304
- Log::Fatal << e.what() << std::endl;
305
- else
306
- Log::Warn << e.what() << std::endl;
307
-
308
- return false;
309
- }
310
- }
311
-
312
- } // namespace data
313
- } // namespace mlpack
314
-
315
- #endif
1
+ /**
2
+ * @file core/data/save_impl.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Implementation of save functionality.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_DATA_SAVE_IMPL_HPP
14
+ #define MLPACK_CORE_DATA_SAVE_IMPL_HPP
15
+
16
+ // In case it hasn't already been included.
17
+ #include "save.hpp"
18
+ #include "extension.hpp"
19
+
20
+ namespace mlpack {
21
+ namespace data {
22
+
23
+ template<typename eT>
24
+ bool Save(const std::string& filename,
25
+ const arma::Col<eT>& vec,
26
+ const bool fatal,
27
+ FileType inputSaveType)
28
+ {
29
+ // Don't transpose: one observation per line (for CSVs at least).
30
+ return Save(filename, vec, fatal, false, inputSaveType);
31
+ }
32
+
33
+ template<typename eT>
34
+ bool Save(const std::string& filename,
35
+ const arma::Row<eT>& rowvec,
36
+ const bool fatal,
37
+ FileType inputSaveType)
38
+ {
39
+ return Save(filename, rowvec, fatal, true, inputSaveType);
40
+ }
41
+
42
+ // Save a Sparse Matrix
43
+ template<typename eT>
44
+ bool Save(const std::string& filename,
45
+ const arma::SpMat<eT>& matrix,
46
+ const bool fatal,
47
+ bool transpose)
48
+ {
49
+ MatrixOptions opts;
50
+ opts.Fatal() = fatal;
51
+ opts.NoTranspose() = !transpose;
52
+
53
+ return Save(filename, matrix, opts);
54
+ }
55
+
56
+ template<typename eT>
57
+ bool Save(const std::string& filename,
58
+ const arma::Mat<eT>& matrix,
59
+ const bool fatal,
60
+ bool transpose,
61
+ FileType inputSaveType)
62
+ {
63
+ MatrixOptions opts;
64
+ opts.Fatal() = fatal;
65
+ opts.NoTranspose() = !transpose;
66
+ opts.Format() = inputSaveType;
67
+
68
+ return Save(filename, matrix, opts);
69
+ }
70
+
71
+ template<typename MatType, typename DataOptionsType>
72
+ bool Save(const std::string& filename,
73
+ const MatType& matrix,
74
+ const DataOptionsType& opts,
75
+ std::enable_if_t<IsArma<MatType>::value ||
76
+ IsSparseMat<MatType>::value>*)
77
+ {
78
+ //! just use default copy ctor with = operator and make a copy.
79
+ DataOptionsType copyOpts(opts);
80
+ return Save(filename, matrix, copyOpts);
81
+ }
82
+
83
+ /*
84
+ * Add this SFINAE in here because the compiler is so stupid that it is not
85
+ * able to distinguish between these two:
86
+ *
87
+ * data::Save(filename, "model", *output);
88
+ *
89
+ * and
90
+ *
91
+ * data::Save(filename, matrix, opts);
92
+ *
93
+ * The second SFINAE is added because the compiler is bot able to see the
94
+ * difference between:
95
+ *
96
+ * data::Save(filename, Row/Col, fatal);
97
+ *
98
+ * and
99
+ *
100
+ * data::Save(filename, Row/Col, Opts);
101
+ *
102
+ * This SFINAE is temporary and must be removed after the integration of stage 3 or
103
+ * when the compiler becomes more intelligent.
104
+ */
105
+ template<typename MatType, typename DataOptionsType>
106
+ bool Save(const std::string& filename,
107
+ const MatType& matrix,
108
+ DataOptionsType& opts,
109
+ std::enable_if_t<IsArma<MatType>::value ||
110
+ IsSparseMat<MatType>::value>*,
111
+ std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
112
+ {
113
+ Timer::Start("saving_data");
114
+
115
+ bool success = DetectFileType<MatType>(filename, opts, false);
116
+ if (!success)
117
+ {
118
+ Timer::Stop("saving_data");
119
+ return false;
120
+ }
121
+
122
+ std::fstream stream;
123
+ success = OpenFile(filename, opts, false, stream);
124
+ if (!success)
125
+ {
126
+ Timer::Stop("saving_data");
127
+ return false;
128
+ }
129
+
130
+ // Try to save the file.
131
+ Log::Info << "Saving " << opts.FileTypeToString() << " to '" << filename
132
+ << "'." << std::endl;
133
+ if constexpr (IsArma<MatType>::value || IsSparseMat<MatType>::value)
134
+ {
135
+ TextOptions txtOpts(std::move(opts));
136
+ if constexpr (IsSparseMat<MatType>::value)
137
+ {
138
+ success = SaveSparse(matrix, txtOpts, filename, stream);
139
+ }
140
+ else if constexpr (IsCol<MatType>::value)
141
+ {
142
+ opts.NoTranspose() = true;
143
+ success = SaveDense(matrix, txtOpts, filename, stream);
144
+ }
145
+ else if constexpr (IsRow<MatType>::value)
146
+ {
147
+ opts.NoTranspose() = false;
148
+ success = SaveDense(matrix, txtOpts, filename, stream);
149
+ }
150
+ else if constexpr (IsDense<MatType>::value)
151
+ {
152
+ success = SaveDense(matrix, txtOpts, filename, stream);
153
+ }
154
+ opts = std::move(txtOpts);
155
+ }
156
+ else
157
+ {
158
+ if (opts.Fatal())
159
+ Log::Fatal << "DataOptionsType is unknown! Please use a known type or "
160
+ << "or provide specific overloads." << std::endl;
161
+ else
162
+ Log::Warn << "DataOptionsType is unknown! Please use a known type or "
163
+ << "or provide specific overloads." << std::endl;
164
+
165
+ return false;
166
+ }
167
+
168
+ if (!success)
169
+ {
170
+ Timer::Stop("saving_data");
171
+ if (opts.Fatal())
172
+ Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
173
+ else
174
+ Log::Warn << "Save to '" << filename << "' failed." << std::endl;
175
+ return false;
176
+ }
177
+
178
+ Timer::Stop("saving_data");
179
+
180
+ return success;
181
+ }
182
+
183
+ template<typename eT>
184
+ bool SaveDense(const arma::Mat<eT>& matrix,
185
+ TextOptions& opts,
186
+ const std::string& filename,
187
+ std::fstream& stream)
188
+ {
189
+ bool success = false;
190
+ arma::Mat<eT> tmp;
191
+ // Transpose the matrix.
192
+ if (!opts.NoTranspose())
193
+ {
194
+ tmp = trans(matrix);
195
+ success = SaveMatrix(tmp, opts, filename, stream);
196
+ }
197
+ else
198
+ success = SaveMatrix(matrix, opts, filename, stream);
199
+
200
+ return success;
201
+ }
202
+
203
+ // Save a Sparse Matrix
204
+ template<typename eT>
205
+ bool SaveSparse(const arma::SpMat<eT>& matrix,
206
+ TextOptions& opts,
207
+ const std::string& filename,
208
+ std::fstream& stream)
209
+ {
210
+ bool success = false;
211
+ arma::SpMat<eT> tmp;
212
+
213
+ // Transpose the matrix.
214
+ if (!opts.NoTranspose())
215
+ {
216
+ arma::SpMat<eT> tmp = trans(matrix);
217
+ success = SaveMatrix(tmp, opts, filename, stream);
218
+ }
219
+ else
220
+ success = SaveMatrix(matrix, opts, filename, stream);
221
+
222
+ return success;
223
+ }
224
+
225
+ //! Save a model to file.
226
+ template<typename T>
227
+ bool Save(const std::string& filename,
228
+ const std::string& name,
229
+ T& t,
230
+ const bool fatal,
231
+ format f,
232
+ std::enable_if_t<HasSerialize<T>::value>*)
233
+ {
234
+ if (f == format::autodetect)
235
+ {
236
+ std::string extension = Extension(filename);
237
+
238
+ if (extension == "xml")
239
+ f = format::xml;
240
+ else if (extension == "bin")
241
+ f = format::binary;
242
+ else if (extension == "json")
243
+ f = format::json;
244
+ else
245
+ {
246
+ if (fatal)
247
+ Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
248
+ << " extension? (allowed: xml/bin/json)" << std::endl;
249
+ else
250
+ Log::Warn << "Unable to detect type of '" << filename << "'; save "
251
+ << "failed. Incorrect extension? (allowed: xml/bin/json)"
252
+ << std::endl;
253
+
254
+ return false;
255
+ }
256
+ }
257
+
258
+ // Open the file to save to.
259
+ std::ofstream ofs;
260
+ #ifdef _WIN32
261
+ if (f == format::binary) // Open non-text types in binary mode on Windows.
262
+ ofs.open(filename, std::ofstream::out | std::ofstream::binary);
263
+ else
264
+ ofs.open(filename, std::ofstream::out);
265
+ #else
266
+ ofs.open(filename, std::ofstream::out);
267
+ #endif
268
+
269
+ if (!ofs.is_open())
270
+ {
271
+ if (fatal)
272
+ Log::Fatal << "Unable to open file '" << filename << "' to save object '"
273
+ << name << "'." << std::endl;
274
+ else
275
+ Log::Warn << "Unable to open file '" << filename << "' to save object '"
276
+ << name << "'." << std::endl;
277
+
278
+ return false;
279
+ }
280
+
281
+ try
282
+ {
283
+ if (f == format::xml)
284
+ {
285
+ cereal::XMLOutputArchive ar(ofs);
286
+ ar(cereal::make_nvp(name.c_str(), t));
287
+ }
288
+ else if (f == format::json)
289
+ {
290
+ cereal::JSONOutputArchive ar(ofs);
291
+ ar(cereal::make_nvp(name.c_str(), t));
292
+ }
293
+ else if (f == format::binary)
294
+ {
295
+ cereal::BinaryOutputArchive ar(ofs);
296
+ ar(cereal::make_nvp(name.c_str(), t));
297
+ }
298
+
299
+ return true;
300
+ }
301
+ catch (cereal::Exception& e)
302
+ {
303
+ if (fatal)
304
+ Log::Fatal << e.what() << std::endl;
305
+ else
306
+ Log::Warn << e.what() << std::endl;
307
+
308
+ return false;
309
+ }
310
+ }
311
+
312
+ } // namespace data
313
+ } // namespace mlpack
314
+
315
+ #endif