mlpack 4.6.1__cp313-cp313-win_amd64.whl → 4.6.2__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +1 -1
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +21 -12
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +49 -39
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +9 -46
- mlpack/include/mlpack/core/data/save_impl.hpp +315 -315
- mlpack/include/mlpack/core/data/utilities.hpp +158 -158
- mlpack/include/mlpack/core/math/ccov.hpp +1 -0
- mlpack/include/mlpack/core/math/ccov_impl.hpp +4 -5
- mlpack/include/mlpack/core/math/make_alias.hpp +98 -3
- mlpack/include/mlpack/core/util/arma_traits.hpp +19 -2
- mlpack/include/mlpack/core/util/gitversion.hpp +1 -1
- mlpack/include/mlpack/core/util/sfinae_utility.hpp +24 -2
- mlpack/include/mlpack/core/util/version.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -2
- mlpack/include/mlpack/methods/ann/init_rules/network_init.hpp +5 -5
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +3 -2
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +6 -7
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +11 -14
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +5 -4
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +15 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +3 -2
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +14 -15
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +6 -5
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +4 -4
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +1 -0
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +3 -2
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +5 -4
- mlpack/include/mlpack/methods/ann/rnn.hpp +19 -18
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +15 -15
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_impl.hpp +3 -8
- mlpack/include/mlpack/methods/decision_tree/fitness_functions/gini_gain.hpp +5 -8
- mlpack/include/mlpack/methods/decision_tree/fitness_functions/information_gain.hpp +5 -8
- mlpack/include/mlpack/methods/gmm/diagonal_gmm_impl.hpp +2 -1
- mlpack/include/mlpack/methods/gmm/eigenvalue_ratio_constraint.hpp +3 -3
- mlpack/include/mlpack/methods/gmm/gmm_impl.hpp +2 -1
- mlpack/include/mlpack/methods/hmm/hmm_impl.hpp +10 -5
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +57 -37
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +69 -59
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.6.2.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/METADATA +2 -2
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/RECORD +101 -101
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/WHEEL +1 -1
- mlpack-4.6.1.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/top_level.txt +0 -0
|
@@ -1,158 +1,158 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @file core/data/utilities.hpp
|
|
3
|
-
* @author Ryan Curtin
|
|
4
|
-
* @author Omar Shrit
|
|
5
|
-
* @author Gopi Tatiraju
|
|
6
|
-
*
|
|
7
|
-
* Utilities functions that can be used during loading and saving the data..
|
|
8
|
-
*
|
|
9
|
-
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
10
|
-
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
11
|
-
* 3-clause BSD license along with mlpack. If not, see
|
|
12
|
-
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
13
|
-
*/
|
|
14
|
-
#ifndef MLPACK_CORE_DATA_UTILITIES_HPP
|
|
15
|
-
#define MLPACK_CORE_DATA_UTILITIES_HPP
|
|
16
|
-
|
|
17
|
-
#include <mlpack/prereqs.hpp>
|
|
18
|
-
|
|
19
|
-
#include "detect_file_type.hpp"
|
|
20
|
-
|
|
21
|
-
namespace mlpack {
|
|
22
|
-
namespace data {
|
|
23
|
-
|
|
24
|
-
namespace details {
|
|
25
|
-
|
|
26
|
-
template<typename Tokenizer>
|
|
27
|
-
std::vector<std::string> ToTokens(Tokenizer& lineTok)
|
|
28
|
-
{
|
|
29
|
-
std::vector<std::string> tokens;
|
|
30
|
-
std::transform(std::begin(lineTok), std::end(lineTok),
|
|
31
|
-
std::back_inserter(tokens),
|
|
32
|
-
[&tokens](std::string const &str)
|
|
33
|
-
{
|
|
34
|
-
std::string trimmedToken(str);
|
|
35
|
-
Trim(trimmedToken);
|
|
36
|
-
return std::move(trimmedToken);
|
|
37
|
-
});
|
|
38
|
-
|
|
39
|
-
return tokens;
|
|
40
|
-
}
|
|
41
|
-
|
|
42
|
-
inline
|
|
43
|
-
void TransposeTokens(std::vector<std::vector<std::string>> const &input,
|
|
44
|
-
std::vector<std::string>& output,
|
|
45
|
-
size_t index)
|
|
46
|
-
{
|
|
47
|
-
output.clear();
|
|
48
|
-
for (size_t i = 0; i != input.size(); ++i)
|
|
49
|
-
{
|
|
50
|
-
output.emplace_back(input[i][index]);
|
|
51
|
-
}
|
|
52
|
-
}
|
|
53
|
-
} // namespace details
|
|
54
|
-
|
|
55
|
-
template<typename DataOptionsType>
|
|
56
|
-
bool OpenFile(const std::string& filename,
|
|
57
|
-
DataOptionsType& opts,
|
|
58
|
-
bool isLoading,
|
|
59
|
-
std::fstream& stream)
|
|
60
|
-
{
|
|
61
|
-
if (isLoading)
|
|
62
|
-
{
|
|
63
|
-
#ifdef _WIN32 // Always open in binary mode on Windows.
|
|
64
|
-
stream.open(filename.c_str(), std::fstream::in
|
|
65
|
-
| std::fstream::binary);
|
|
66
|
-
#else
|
|
67
|
-
stream.open(filename.c_str(), std::fstream::in);
|
|
68
|
-
#endif
|
|
69
|
-
}
|
|
70
|
-
// Add here and else if for ModelOptions in a couple of stages.
|
|
71
|
-
else
|
|
72
|
-
{
|
|
73
|
-
#ifdef _WIN32 // Always open in binary mode on Windows.
|
|
74
|
-
stream.open(filename.c_str(), std::fstream::out
|
|
75
|
-
| std::fstream::binary);
|
|
76
|
-
#else
|
|
77
|
-
stream.open(filename.c_str(), std::fstream::out);
|
|
78
|
-
#endif
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
if (!stream.is_open())
|
|
82
|
-
{
|
|
83
|
-
if (opts.Fatal() && isLoading)
|
|
84
|
-
Log::Fatal << "Cannot open file '" << filename << "' for loading. "
|
|
85
|
-
<< "Please check if the file exists." << std::endl;
|
|
86
|
-
|
|
87
|
-
else if (!opts.Fatal() && isLoading)
|
|
88
|
-
Log::Warn << "Cannot open file '" << filename << "' for loading. "
|
|
89
|
-
<< "Please check if the file exists." << std::endl;
|
|
90
|
-
|
|
91
|
-
else if (opts.Fatal() && !isLoading)
|
|
92
|
-
Log::Fatal << "Cannot open file '" << filename << "' for saving. "
|
|
93
|
-
<< "Please check if you have permissions for writing." << std::endl;
|
|
94
|
-
|
|
95
|
-
else if (!opts.Fatal() && !isLoading)
|
|
96
|
-
Log::Warn << "Cannot open file '" << filename << "' for saving. "
|
|
97
|
-
<< "Please check if you have permissions for writing." << std::endl;
|
|
98
|
-
|
|
99
|
-
return false;
|
|
100
|
-
}
|
|
101
|
-
return true;
|
|
102
|
-
}
|
|
103
|
-
|
|
104
|
-
template<typename MatType, typename DataOptionsType>
|
|
105
|
-
bool DetectFileType(const std::string& filename,
|
|
106
|
-
DataOptionsType& opts,
|
|
107
|
-
bool isLoading,
|
|
108
|
-
std::fstream* stream = nullptr)
|
|
109
|
-
{
|
|
110
|
-
// Add if for ModelOptions in a couple of stages
|
|
111
|
-
if (opts.Format() == FileType::AutoDetect)
|
|
112
|
-
{
|
|
113
|
-
if (isLoading)
|
|
114
|
-
// Attempt to auto-detect the type from the given file.
|
|
115
|
-
opts.Format() = AutoDetect(*stream, filename);
|
|
116
|
-
else
|
|
117
|
-
DetectFromExtension<MatType>(filename, opts);
|
|
118
|
-
// Provide error if we don't know the type.
|
|
119
|
-
if (opts.Format() == FileType::FileTypeUnknown)
|
|
120
|
-
{
|
|
121
|
-
if (opts.Fatal())
|
|
122
|
-
Log::Fatal << "Unable to detect type of '" << filename << "'; "
|
|
123
|
-
<< "Incorrect extension?" << std::endl;
|
|
124
|
-
else
|
|
125
|
-
Log::Warn << "Unable to detect type of '" << filename << "'; "
|
|
126
|
-
<< "Incorrect extension?" << std::endl;
|
|
127
|
-
|
|
128
|
-
return false;
|
|
129
|
-
}
|
|
130
|
-
}
|
|
131
|
-
return true;
|
|
132
|
-
}
|
|
133
|
-
|
|
134
|
-
template<typename MatType, typename DataOptionsType>
|
|
135
|
-
bool SaveMatrix(const MatType& matrix,
|
|
136
|
-
const DataOptionsType& opts,
|
|
137
|
-
const std::string& filename,
|
|
138
|
-
std::fstream& stream)
|
|
139
|
-
{
|
|
140
|
-
bool success = false;
|
|
141
|
-
if (opts.Format() == FileType::HDF5Binary)
|
|
142
|
-
{
|
|
143
|
-
#ifdef ARMA_USE_HDF5
|
|
144
|
-
// We can't save with streams for HDF5.
|
|
145
|
-
success = matrix.save(filename, ToArmaFileType(opts.Format()))
|
|
146
|
-
#endif
|
|
147
|
-
}
|
|
148
|
-
else
|
|
149
|
-
{
|
|
150
|
-
success = matrix.save(stream, ToArmaFileType(opts.Format()));
|
|
151
|
-
}
|
|
152
|
-
return success;
|
|
153
|
-
}
|
|
154
|
-
|
|
155
|
-
} //namespace data
|
|
156
|
-
} //namespace mlpack
|
|
157
|
-
|
|
158
|
-
#endif
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/utilities.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
* @author Gopi Tatiraju
|
|
6
|
+
*
|
|
7
|
+
* Utilities functions that can be used during loading and saving the data..
|
|
8
|
+
*
|
|
9
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
10
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
11
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
12
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
13
|
+
*/
|
|
14
|
+
#ifndef MLPACK_CORE_DATA_UTILITIES_HPP
|
|
15
|
+
#define MLPACK_CORE_DATA_UTILITIES_HPP
|
|
16
|
+
|
|
17
|
+
#include <mlpack/prereqs.hpp>
|
|
18
|
+
|
|
19
|
+
#include "detect_file_type.hpp"
|
|
20
|
+
|
|
21
|
+
namespace mlpack {
|
|
22
|
+
namespace data {
|
|
23
|
+
|
|
24
|
+
namespace details {
|
|
25
|
+
|
|
26
|
+
template<typename Tokenizer>
|
|
27
|
+
std::vector<std::string> ToTokens(Tokenizer& lineTok)
|
|
28
|
+
{
|
|
29
|
+
std::vector<std::string> tokens;
|
|
30
|
+
std::transform(std::begin(lineTok), std::end(lineTok),
|
|
31
|
+
std::back_inserter(tokens),
|
|
32
|
+
[&tokens](std::string const &str)
|
|
33
|
+
{
|
|
34
|
+
std::string trimmedToken(str);
|
|
35
|
+
Trim(trimmedToken);
|
|
36
|
+
return std::move(trimmedToken);
|
|
37
|
+
});
|
|
38
|
+
|
|
39
|
+
return tokens;
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
inline
|
|
43
|
+
void TransposeTokens(std::vector<std::vector<std::string>> const &input,
|
|
44
|
+
std::vector<std::string>& output,
|
|
45
|
+
size_t index)
|
|
46
|
+
{
|
|
47
|
+
output.clear();
|
|
48
|
+
for (size_t i = 0; i != input.size(); ++i)
|
|
49
|
+
{
|
|
50
|
+
output.emplace_back(input[i][index]);
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
} // namespace details
|
|
54
|
+
|
|
55
|
+
template<typename DataOptionsType>
|
|
56
|
+
bool OpenFile(const std::string& filename,
|
|
57
|
+
DataOptionsType& opts,
|
|
58
|
+
bool isLoading,
|
|
59
|
+
std::fstream& stream)
|
|
60
|
+
{
|
|
61
|
+
if (isLoading)
|
|
62
|
+
{
|
|
63
|
+
#ifdef _WIN32 // Always open in binary mode on Windows.
|
|
64
|
+
stream.open(filename.c_str(), std::fstream::in
|
|
65
|
+
| std::fstream::binary);
|
|
66
|
+
#else
|
|
67
|
+
stream.open(filename.c_str(), std::fstream::in);
|
|
68
|
+
#endif
|
|
69
|
+
}
|
|
70
|
+
// Add here and else if for ModelOptions in a couple of stages.
|
|
71
|
+
else
|
|
72
|
+
{
|
|
73
|
+
#ifdef _WIN32 // Always open in binary mode on Windows.
|
|
74
|
+
stream.open(filename.c_str(), std::fstream::out
|
|
75
|
+
| std::fstream::binary);
|
|
76
|
+
#else
|
|
77
|
+
stream.open(filename.c_str(), std::fstream::out);
|
|
78
|
+
#endif
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
if (!stream.is_open())
|
|
82
|
+
{
|
|
83
|
+
if (opts.Fatal() && isLoading)
|
|
84
|
+
Log::Fatal << "Cannot open file '" << filename << "' for loading. "
|
|
85
|
+
<< "Please check if the file exists." << std::endl;
|
|
86
|
+
|
|
87
|
+
else if (!opts.Fatal() && isLoading)
|
|
88
|
+
Log::Warn << "Cannot open file '" << filename << "' for loading. "
|
|
89
|
+
<< "Please check if the file exists." << std::endl;
|
|
90
|
+
|
|
91
|
+
else if (opts.Fatal() && !isLoading)
|
|
92
|
+
Log::Fatal << "Cannot open file '" << filename << "' for saving. "
|
|
93
|
+
<< "Please check if you have permissions for writing." << std::endl;
|
|
94
|
+
|
|
95
|
+
else if (!opts.Fatal() && !isLoading)
|
|
96
|
+
Log::Warn << "Cannot open file '" << filename << "' for saving. "
|
|
97
|
+
<< "Please check if you have permissions for writing." << std::endl;
|
|
98
|
+
|
|
99
|
+
return false;
|
|
100
|
+
}
|
|
101
|
+
return true;
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
template<typename MatType, typename DataOptionsType>
|
|
105
|
+
bool DetectFileType(const std::string& filename,
|
|
106
|
+
DataOptionsType& opts,
|
|
107
|
+
bool isLoading,
|
|
108
|
+
std::fstream* stream = nullptr)
|
|
109
|
+
{
|
|
110
|
+
// Add if for ModelOptions in a couple of stages
|
|
111
|
+
if (opts.Format() == FileType::AutoDetect)
|
|
112
|
+
{
|
|
113
|
+
if (isLoading)
|
|
114
|
+
// Attempt to auto-detect the type from the given file.
|
|
115
|
+
opts.Format() = AutoDetect(*stream, filename);
|
|
116
|
+
else
|
|
117
|
+
DetectFromExtension<MatType>(filename, opts);
|
|
118
|
+
// Provide error if we don't know the type.
|
|
119
|
+
if (opts.Format() == FileType::FileTypeUnknown)
|
|
120
|
+
{
|
|
121
|
+
if (opts.Fatal())
|
|
122
|
+
Log::Fatal << "Unable to detect type of '" << filename << "'; "
|
|
123
|
+
<< "Incorrect extension?" << std::endl;
|
|
124
|
+
else
|
|
125
|
+
Log::Warn << "Unable to detect type of '" << filename << "'; "
|
|
126
|
+
<< "Incorrect extension?" << std::endl;
|
|
127
|
+
|
|
128
|
+
return false;
|
|
129
|
+
}
|
|
130
|
+
}
|
|
131
|
+
return true;
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
template<typename MatType, typename DataOptionsType>
|
|
135
|
+
bool SaveMatrix(const MatType& matrix,
|
|
136
|
+
const DataOptionsType& opts,
|
|
137
|
+
const std::string& filename,
|
|
138
|
+
std::fstream& stream)
|
|
139
|
+
{
|
|
140
|
+
bool success = false;
|
|
141
|
+
if (opts.Format() == FileType::HDF5Binary)
|
|
142
|
+
{
|
|
143
|
+
#ifdef ARMA_USE_HDF5
|
|
144
|
+
// We can't save with streams for HDF5.
|
|
145
|
+
success = matrix.save(filename, ToArmaFileType(opts.Format()));
|
|
146
|
+
#endif
|
|
147
|
+
}
|
|
148
|
+
else
|
|
149
|
+
{
|
|
150
|
+
success = matrix.save(stream, ToArmaFileType(opts.Format()));
|
|
151
|
+
}
|
|
152
|
+
return success;
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
} //namespace data
|
|
156
|
+
} //namespace mlpack
|
|
157
|
+
|
|
158
|
+
#endif
|
|
@@ -31,11 +31,10 @@ inline arma::Mat<eT> ColumnCovariance(const arma::Mat<eT>& x,
|
|
|
31
31
|
|
|
32
32
|
if (x.n_elem > 0)
|
|
33
33
|
{
|
|
34
|
-
const arma::Mat<eT
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
false);
|
|
34
|
+
const arma::Mat<eT> xAlias;
|
|
35
|
+
MakeAlias(const_cast<arma::Mat<eT>&>(xAlias), x,
|
|
36
|
+
(x.n_cols == 1) ? x.n_cols : x.n_rows,
|
|
37
|
+
(x.n_cols == 1) ? x.n_rows : x.n_cols, 0, false);
|
|
39
38
|
|
|
40
39
|
const size_t n = xAlias.n_cols;
|
|
41
40
|
const eT normVal = (normType == 0) ? ((n > 1) ? eT(n - 1) : eT(1)) : eT(n);
|
|
@@ -26,7 +26,9 @@ void MakeAlias(OutVecType& v,
|
|
|
26
26
|
const size_t offset = 0,
|
|
27
27
|
const bool strict = true,
|
|
28
28
|
const typename std::enable_if_t<
|
|
29
|
-
IsVector<OutVecType>::value
|
|
29
|
+
IsVector<OutVecType>::value &&
|
|
30
|
+
IsArma<InVecType>::value &&
|
|
31
|
+
IsArma<OutVecType>::value>* = 0)
|
|
30
32
|
{
|
|
31
33
|
// We use placement new to reinitialize the object, since the copy and move
|
|
32
34
|
// assignment operators in Armadillo will end up copying memory instead of
|
|
@@ -49,7 +51,9 @@ void MakeAlias(OutMatType& m,
|
|
|
49
51
|
const size_t offset = 0,
|
|
50
52
|
const bool strict = true,
|
|
51
53
|
const typename std::enable_if_t<
|
|
52
|
-
IsMatrix<OutMatType>::value
|
|
54
|
+
IsMatrix<OutMatType>::value &&
|
|
55
|
+
IsArma<InMatType>::value &&
|
|
56
|
+
IsArma<OutMatType>::value>* = 0)
|
|
53
57
|
{
|
|
54
58
|
// We use placement new to reinitialize the object, since the copy and move
|
|
55
59
|
// assignment operators in Armadillo will end up copying memory instead of
|
|
@@ -72,7 +76,10 @@ void MakeAlias(OutCubeType& c,
|
|
|
72
76
|
const size_t numSlices,
|
|
73
77
|
const size_t offset = 0,
|
|
74
78
|
const bool strict = true,
|
|
75
|
-
const typename std::enable_if_t<
|
|
79
|
+
const typename std::enable_if_t<
|
|
80
|
+
IsCube<OutCubeType>::value &&
|
|
81
|
+
IsArma<InCubeType>::value &&
|
|
82
|
+
IsArma<OutCubeType>::value>* = 0)
|
|
76
83
|
{
|
|
77
84
|
// We use placement new to reinitialize the object, since the copy and move
|
|
78
85
|
// assignment operators in Armadillo will end up copying memory instead of
|
|
@@ -119,6 +126,94 @@ void ClearAlias(arma::SpMat<ElemType>& /* mat */)
|
|
|
119
126
|
// We cannot make aliases of sparse matrices, so, nothing to do.
|
|
120
127
|
}
|
|
121
128
|
|
|
129
|
+
#if defined(MLPACK_HAS_COOT)
|
|
130
|
+
|
|
131
|
+
/**
|
|
132
|
+
* Reconstruct `v` as an alias around the memory `newMem`, with size `numRows` x
|
|
133
|
+
* `numCols`.
|
|
134
|
+
*/
|
|
135
|
+
template<typename InVecType, typename OutVecType>
|
|
136
|
+
void MakeAlias(OutVecType& v,
|
|
137
|
+
const InVecType& oldVec,
|
|
138
|
+
const size_t numElems,
|
|
139
|
+
const size_t offset = 0,
|
|
140
|
+
const bool strict = true,
|
|
141
|
+
const typename std::enable_if_t<
|
|
142
|
+
IsVector<OutVecType>::value &&
|
|
143
|
+
IsCoot<InVecType>::value &&
|
|
144
|
+
IsCoot<OutVecType>::value>* = 0)
|
|
145
|
+
{
|
|
146
|
+
// We use placement new to reinitialize the object, since the copy and move
|
|
147
|
+
// assignment operators in Bandicoot will end up copying memory instead of
|
|
148
|
+
// making an alias.
|
|
149
|
+
coot::dev_mem_t<InVecType::elem_type> newMem = oldVec.get_dev_mem() + offset;
|
|
150
|
+
v.~OutVecType();
|
|
151
|
+
new (&v) OutVecType(newMem, numElems, false, strict);
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
/**
|
|
155
|
+
* Reconstruct `m` as an alias around the memory `newMem`, with size `numRows` x
|
|
156
|
+
* `numCols`.
|
|
157
|
+
*/
|
|
158
|
+
template<typename InMatType, typename OutMatType>
|
|
159
|
+
void MakeAlias(OutMatType& m,
|
|
160
|
+
const InMatType& oldMat,
|
|
161
|
+
const size_t numRows,
|
|
162
|
+
const size_t numCols,
|
|
163
|
+
const size_t offset = 0,
|
|
164
|
+
const bool strict = true,
|
|
165
|
+
const typename std::enable_if_t<
|
|
166
|
+
IsMatrix<OutMatType>::value &&
|
|
167
|
+
IsCoot<InMatType>::value &&
|
|
168
|
+
IsCoot<OutMatType>::value>* = 0)
|
|
169
|
+
{
|
|
170
|
+
// We use placement new to reinitialize the object, since the copy and move
|
|
171
|
+
// assignment operators in Bandicoot will end up copying memory instead of
|
|
172
|
+
// making an alias.
|
|
173
|
+
coot::dev_mem_t<InMatType::elem_type> newMem = oldMat.get_dev_mem() + offset;
|
|
174
|
+
m.~OutMatType();
|
|
175
|
+
new (&m) OutMatType(newMem, numRows, numCols);
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
/**
|
|
179
|
+
* Reconstruct `c` as an alias around the memory` newMem`, with size `numRows` x
|
|
180
|
+
* `numCols` x `numSlices`.
|
|
181
|
+
*/
|
|
182
|
+
template<typename InCubeType, typename OutCubeType>
|
|
183
|
+
void MakeAlias(OutCubeType& c,
|
|
184
|
+
const InCubeType& oldCube,
|
|
185
|
+
const size_t numRows,
|
|
186
|
+
const size_t numCols,
|
|
187
|
+
const size_t numSlices,
|
|
188
|
+
const size_t offset = 0,
|
|
189
|
+
const bool strict = true,
|
|
190
|
+
const typename std::enable_if_t<
|
|
191
|
+
IsCube<OutCubeType>::value &&
|
|
192
|
+
IsCoot<InCubeType>::value &&
|
|
193
|
+
IsCoot<OutCubeType>::value>* = 0)
|
|
194
|
+
{
|
|
195
|
+
// We use placement new to reinitialize the object, since the copy and move
|
|
196
|
+
// assignment operators in Bandicoot will end up copying memory instead of
|
|
197
|
+
// making an alias.
|
|
198
|
+
coot::dev_mem_t<InCubeType::elem_type> newMem =
|
|
199
|
+
oldCube.get_dev_mem() + offset;
|
|
200
|
+
c.~OutCubeType();
|
|
201
|
+
new (&c) OutCubeType(newMem, numRows, numCols, numSlices);
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
/**
|
|
205
|
+
* Clear an alias so that no data is overwritten. This resets the matrix if it
|
|
206
|
+
* is an alias (and does nothing otherwise).
|
|
207
|
+
*/
|
|
208
|
+
template<typename ElemType>
|
|
209
|
+
void ClearAlias(coot::Mat<ElemType>& mat)
|
|
210
|
+
{
|
|
211
|
+
if (mat.mem_state >= 1)
|
|
212
|
+
mat.reset();
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
#endif // defined(MLPACK_HAS_COOT)
|
|
216
|
+
|
|
122
217
|
} // namespace mlpack
|
|
123
218
|
|
|
124
219
|
#endif
|
|
@@ -239,6 +239,16 @@ struct GetCubeType<arma::Mat<eT>>
|
|
|
239
239
|
using type = arma::Cube<eT>;
|
|
240
240
|
};
|
|
241
241
|
|
|
242
|
+
#if defined(MLPACK_HAS_COOT)
|
|
243
|
+
|
|
244
|
+
template<typename eT>
|
|
245
|
+
struct GetCubeType<coot::Mat<eT>>
|
|
246
|
+
{
|
|
247
|
+
using type = coot::Cube<eT>;
|
|
248
|
+
};
|
|
249
|
+
|
|
250
|
+
#endif
|
|
251
|
+
|
|
242
252
|
// Get the sparse matrix type corresponding to a given MatType.
|
|
243
253
|
|
|
244
254
|
template<typename MatType>
|
|
@@ -346,18 +356,25 @@ struct IsDense<arma::Mat<eT>>
|
|
|
346
356
|
constexpr static bool value = true;
|
|
347
357
|
};
|
|
348
358
|
|
|
359
|
+
// Get whether or not the given type is any non-field Armadillo type
|
|
360
|
+
// This includes sparse, dense, and cube types
|
|
349
361
|
template<typename T>
|
|
350
362
|
struct IsArma
|
|
351
363
|
{
|
|
352
|
-
constexpr static bool value = arma::is_arma_type<T>::value
|
|
364
|
+
constexpr static bool value = arma::is_arma_type<T>::value ||
|
|
365
|
+
arma::is_arma_cube_type<T>::value ||
|
|
366
|
+
arma::is_arma_sparse_type<T>::value;
|
|
353
367
|
};
|
|
354
368
|
|
|
355
369
|
#if defined(MLPACK_HAS_COOT)
|
|
356
370
|
|
|
371
|
+
// Get whether or not the given type is any Bandicoot type
|
|
372
|
+
// This includes dense and cube types
|
|
357
373
|
template<typename T>
|
|
358
374
|
struct IsCoot
|
|
359
375
|
{
|
|
360
|
-
constexpr static bool value = coot::is_coot_type<T>::value
|
|
376
|
+
constexpr static bool value = coot::is_coot_type<T>::value ||
|
|
377
|
+
coot::is_coot_cube_type<T>::value;
|
|
361
378
|
};
|
|
362
379
|
|
|
363
380
|
#else
|
|
@@ -1 +1 @@
|
|
|
1
|
-
return "mlpack git-
|
|
1
|
+
return "mlpack git-0fdccbfb21";
|
|
@@ -98,6 +98,28 @@ struct MethodFormDetector<Class, MethodForm, 7>
|
|
|
98
98
|
void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7>);
|
|
99
99
|
};
|
|
100
100
|
|
|
101
|
+
template<typename Class, template<typename...> class MethodForm>
|
|
102
|
+
struct MethodFormDetector<Class, MethodForm, 8>
|
|
103
|
+
{
|
|
104
|
+
template<class T1, class T2, class T3, class T4, class T5, class T6, class T7,
|
|
105
|
+
class T8>
|
|
106
|
+
void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7, T8>);
|
|
107
|
+
};
|
|
108
|
+
template<typename Class, template<typename...> class MethodForm>
|
|
109
|
+
struct MethodFormDetector<Class, MethodForm, 9>
|
|
110
|
+
{
|
|
111
|
+
template<class T1, class T2, class T3, class T4, class T5, class T6, class T7,
|
|
112
|
+
class T8, class T9>
|
|
113
|
+
void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7, T8, T9>);
|
|
114
|
+
};
|
|
115
|
+
template<typename Class, template<typename...> class MethodForm>
|
|
116
|
+
struct MethodFormDetector<Class, MethodForm, 10>
|
|
117
|
+
{
|
|
118
|
+
template<class T1, class T2, class T3, class T4, class T5, class T6, class T7,
|
|
119
|
+
class T8, class T9, class T10>
|
|
120
|
+
void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>);
|
|
121
|
+
};
|
|
122
|
+
|
|
101
123
|
//! Utility struct for checking signatures.
|
|
102
124
|
template<typename U, U> struct SigCheck : std::true_type {};
|
|
103
125
|
|
|
@@ -237,7 +259,7 @@ struct NAME \
|
|
|
237
259
|
* we can check whether the class A has a Train method of the specified form:
|
|
238
260
|
*
|
|
239
261
|
* HAS_METHOD_FORM(Train, HasTrain);
|
|
240
|
-
* static_assert(HasTrain<A,
|
|
262
|
+
* static_assert(HasTrain<A, TrainForm>::value, "value should be true");
|
|
241
263
|
*
|
|
242
264
|
* The class generated by this will also return true values if the given class
|
|
243
265
|
* has a method that also has extra parameters.
|
|
@@ -246,7 +268,7 @@ struct NAME \
|
|
|
246
268
|
* @param NAME The name of the struct to construct.
|
|
247
269
|
*/
|
|
248
270
|
#define HAS_METHOD_FORM(METHOD, NAME) \
|
|
249
|
-
HAS_METHOD_FORM_BASE(SINGLE_ARG(METHOD), SINGLE_ARG(NAME), 7)
|
|
271
|
+
HAS_METHOD_FORM_BASE(SINGLE_ARG(METHOD), SINGLE_ARG(NAME), 10)
|
|
250
272
|
|
|
251
273
|
/**
|
|
252
274
|
* HAS_EXACT_METHOD_FORM generates a template that allows a compile time check
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
// with higher number than the most recent release.
|
|
19
19
|
#define MLPACK_VERSION_MAJOR 4
|
|
20
20
|
#define MLPACK_VERSION_MINOR 6
|
|
21
|
-
#define MLPACK_VERSION_PATCH 1
|
|
21
|
+
#define MLPACK_VERSION_PATCH 2
|
|
22
22
|
|
|
23
23
|
// The name of the version (for use by --version).
|
|
24
24
|
namespace mlpack {
|
|
@@ -48,9 +48,9 @@ class NetworkInitialization
|
|
|
48
48
|
* @param parameter The network parameter.
|
|
49
49
|
* @param parameterOffset Offset for network paramater, default 0.
|
|
50
50
|
*/
|
|
51
|
-
template <typename
|
|
52
|
-
void Initialize(const std::vector<Layer<
|
|
53
|
-
|
|
51
|
+
template <typename MatType>
|
|
52
|
+
void Initialize(const std::vector<Layer<MatType>*>& network,
|
|
53
|
+
MatType& parameters,
|
|
54
54
|
size_t parameterOffset = 0)
|
|
55
55
|
{
|
|
56
56
|
// Determine the total number of parameters/weights of the given network.
|
|
@@ -71,8 +71,8 @@ class NetworkInitialization
|
|
|
71
71
|
// Initialize the layer with the specified parameter/weight
|
|
72
72
|
// initialization rule.
|
|
73
73
|
const size_t weight = network[i]->WeightSize();
|
|
74
|
-
|
|
75
|
-
|
|
74
|
+
MatType tmp;
|
|
75
|
+
MakeAlias(tmp, parameters, weight, 1, offset, false);
|
|
76
76
|
initializeRule.Initialize(tmp, tmp.n_elem, 1);
|
|
77
77
|
|
|
78
78
|
// Increase the parameter/weight offset for the next layer.
|
|
@@ -53,6 +53,7 @@ template <typename MatType = arma::mat>
|
|
|
53
53
|
class BatchNormType : public Layer<MatType>
|
|
54
54
|
{
|
|
55
55
|
public:
|
|
56
|
+
using CubeType = typename GetCubeType<MatType>::type;
|
|
56
57
|
/**
|
|
57
58
|
* Create the BatchNorm object.
|
|
58
59
|
*
|
|
@@ -266,10 +267,10 @@ class BatchNormType : public Layer<MatType>
|
|
|
266
267
|
MatType runningVariance;
|
|
267
268
|
|
|
268
269
|
//! Locally-stored normalized input.
|
|
269
|
-
|
|
270
|
+
CubeType normalized;
|
|
270
271
|
|
|
271
272
|
//! Locally-stored zero mean input.
|
|
272
|
-
|
|
273
|
+
CubeType inputMean;
|
|
273
274
|
}; // class BatchNorm
|
|
274
275
|
|
|
275
276
|
// Convenience typedefs.
|