mlpack 4.6.1__cp313-cp313-win_amd64.whl → 4.6.2__cp313-cp313-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (102)
  1. mlpack/__init__.py +1 -1
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +21 -12
  24. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +49 -39
  25. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +9 -46
  26. mlpack/include/mlpack/core/data/save_impl.hpp +315 -315
  27. mlpack/include/mlpack/core/data/utilities.hpp +158 -158
  28. mlpack/include/mlpack/core/math/ccov.hpp +1 -0
  29. mlpack/include/mlpack/core/math/ccov_impl.hpp +4 -5
  30. mlpack/include/mlpack/core/math/make_alias.hpp +98 -3
  31. mlpack/include/mlpack/core/util/arma_traits.hpp +19 -2
  32. mlpack/include/mlpack/core/util/gitversion.hpp +1 -1
  33. mlpack/include/mlpack/core/util/sfinae_utility.hpp +24 -2
  34. mlpack/include/mlpack/core/util/version.hpp +1 -1
  35. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -2
  36. mlpack/include/mlpack/methods/ann/init_rules/network_init.hpp +5 -5
  37. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +3 -2
  38. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +19 -20
  39. mlpack/include/mlpack/methods/ann/layer/concat.hpp +1 -0
  40. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +6 -7
  41. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +3 -3
  42. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +3 -3
  43. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +1 -0
  44. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +11 -14
  45. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +5 -4
  46. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +15 -14
  47. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +3 -2
  48. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +14 -15
  49. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +6 -5
  50. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +24 -25
  51. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +1 -0
  52. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +4 -4
  53. mlpack/include/mlpack/methods/ann/layer/padding.hpp +1 -0
  54. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +12 -13
  55. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +3 -2
  56. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +5 -4
  57. mlpack/include/mlpack/methods/ann/rnn.hpp +19 -18
  58. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +15 -15
  59. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_impl.hpp +3 -8
  60. mlpack/include/mlpack/methods/decision_tree/fitness_functions/gini_gain.hpp +5 -8
  61. mlpack/include/mlpack/methods/decision_tree/fitness_functions/information_gain.hpp +5 -8
  62. mlpack/include/mlpack/methods/gmm/diagonal_gmm_impl.hpp +2 -1
  63. mlpack/include/mlpack/methods/gmm/eigenvalue_ratio_constraint.hpp +3 -3
  64. mlpack/include/mlpack/methods/gmm/gmm_impl.hpp +2 -1
  65. mlpack/include/mlpack/methods/hmm/hmm_impl.hpp +10 -5
  66. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +57 -37
  67. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +69 -59
  68. mlpack/kde.cp313-win_amd64.pyd +0 -0
  69. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  70. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  71. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  72. mlpack/knn.cp313-win_amd64.pyd +0 -0
  73. mlpack/krann.cp313-win_amd64.pyd +0 -0
  74. mlpack/lars.cp313-win_amd64.pyd +0 -0
  75. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  76. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  77. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  78. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  79. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  80. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  81. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  82. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  83. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  84. mlpack/nca.cp313-win_amd64.pyd +0 -0
  85. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  86. mlpack/pca.cp313-win_amd64.pyd +0 -0
  87. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  88. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  89. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  90. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  91. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  92. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  93. mlpack/radical.cp313-win_amd64.pyd +0 -0
  94. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  95. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  96. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  97. mlpack-4.6.2.dist-info/DELVEWHEEL +2 -0
  98. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/METADATA +2 -2
  99. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/RECORD +101 -101
  100. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/WHEEL +1 -1
  101. mlpack-4.6.1.dist-info/DELVEWHEEL +0 -2
  102. {mlpack-4.6.1.dist-info → mlpack-4.6.2.dist-info}/top_level.txt +0 -0
@@ -1,158 +1,158 @@
- /**
- * @file core/data/utilities.hpp
- * @author Ryan Curtin
- * @author Omar Shrit
- * @author Gopi Tatiraju
- *
- * Utilities functions that can be used during loading and saving the data..
- *
- * mlpack is free software; you may redistribute it and/or modify it under the
- * terms of the 3-clause BSD license. You should have received a copy of the
- * 3-clause BSD license along with mlpack. If not, see
- * http://www.opensource.org/licenses/BSD-3-Clause for more information.
- */
- #ifndef MLPACK_CORE_DATA_UTILITIES_HPP
- #define MLPACK_CORE_DATA_UTILITIES_HPP
-
- #include <mlpack/prereqs.hpp>
-
- #include "detect_file_type.hpp"
-
- namespace mlpack {
- namespace data {
-
- namespace details {
-
- template<typename Tokenizer>
- std::vector<std::string> ToTokens(Tokenizer& lineTok)
- {
- std::vector<std::string> tokens;
- std::transform(std::begin(lineTok), std::end(lineTok),
- std::back_inserter(tokens),
- [&tokens](std::string const &str)
- {
- std::string trimmedToken(str);
- Trim(trimmedToken);
- return std::move(trimmedToken);
- });
-
- return tokens;
- }
-
- inline
- void TransposeTokens(std::vector<std::vector<std::string>> const &input,
- std::vector<std::string>& output,
- size_t index)
- {
- output.clear();
- for (size_t i = 0; i != input.size(); ++i)
- {
- output.emplace_back(input[i][index]);
- }
- }
- } // namespace details
-
- template<typename DataOptionsType>
- bool OpenFile(const std::string& filename,
- DataOptionsType& opts,
- bool isLoading,
- std::fstream& stream)
- {
- if (isLoading)
- {
- #ifdef _WIN32 // Always open in binary mode on Windows.
- stream.open(filename.c_str(), std::fstream::in
- | std::fstream::binary);
- #else
- stream.open(filename.c_str(), std::fstream::in);
- #endif
- }
- // Add here and else if for ModelOptions in a couple of stages.
- else
- {
- #ifdef _WIN32 // Always open in binary mode on Windows.
- stream.open(filename.c_str(), std::fstream::out
- | std::fstream::binary);
- #else
- stream.open(filename.c_str(), std::fstream::out);
- #endif
- }
-
- if (!stream.is_open())
- {
- if (opts.Fatal() && isLoading)
- Log::Fatal << "Cannot open file '" << filename << "' for loading. "
- << "Please check if the file exists." << std::endl;
-
- else if (!opts.Fatal() && isLoading)
- Log::Warn << "Cannot open file '" << filename << "' for loading. "
- << "Please check if the file exists." << std::endl;
-
- else if (opts.Fatal() && !isLoading)
- Log::Fatal << "Cannot open file '" << filename << "' for saving. "
- << "Please check if you have permissions for writing." << std::endl;
-
- else if (!opts.Fatal() && !isLoading)
- Log::Warn << "Cannot open file '" << filename << "' for saving. "
- << "Please check if you have permissions for writing." << std::endl;
-
- return false;
- }
- return true;
- }
-
- template<typename MatType, typename DataOptionsType>
- bool DetectFileType(const std::string& filename,
- DataOptionsType& opts,
- bool isLoading,
- std::fstream* stream = nullptr)
- {
- // Add if for ModelOptions in a couple of stages
- if (opts.Format() == FileType::AutoDetect)
- {
- if (isLoading)
- // Attempt to auto-detect the type from the given file.
- opts.Format() = AutoDetect(*stream, filename);
- else
- DetectFromExtension<MatType>(filename, opts);
- // Provide error if we don't know the type.
- if (opts.Format() == FileType::FileTypeUnknown)
- {
- if (opts.Fatal())
- Log::Fatal << "Unable to detect type of '" << filename << "'; "
- << "Incorrect extension?" << std::endl;
- else
- Log::Warn << "Unable to detect type of '" << filename << "'; "
- << "Incorrect extension?" << std::endl;
-
- return false;
- }
- }
- return true;
- }
-
- template<typename MatType, typename DataOptionsType>
- bool SaveMatrix(const MatType& matrix,
- const DataOptionsType& opts,
- const std::string& filename,
- std::fstream& stream)
- {
- bool success = false;
- if (opts.Format() == FileType::HDF5Binary)
- {
- #ifdef ARMA_USE_HDF5
- // We can't save with streams for HDF5.
- success = matrix.save(filename, ToArmaFileType(opts.Format()))
- #endif
- }
- else
- {
- success = matrix.save(stream, ToArmaFileType(opts.Format()));
- }
- return success;
- }
-
- } //namespace data
- } //namespace mlpack
-
- #endif
+ /**
+ * @file core/data/utilities.hpp
+ * @author Ryan Curtin
+ * @author Omar Shrit
+ * @author Gopi Tatiraju
+ *
+ * Utilities functions that can be used during loading and saving the data..
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+ #ifndef MLPACK_CORE_DATA_UTILITIES_HPP
+ #define MLPACK_CORE_DATA_UTILITIES_HPP
+
+ #include <mlpack/prereqs.hpp>
+
+ #include "detect_file_type.hpp"
+
+ namespace mlpack {
+ namespace data {
+
+ namespace details {
+
+ template<typename Tokenizer>
+ std::vector<std::string> ToTokens(Tokenizer& lineTok)
+ {
+ std::vector<std::string> tokens;
+ std::transform(std::begin(lineTok), std::end(lineTok),
+ std::back_inserter(tokens),
+ [&tokens](std::string const &str)
+ {
+ std::string trimmedToken(str);
+ Trim(trimmedToken);
+ return std::move(trimmedToken);
+ });
+
+ return tokens;
+ }
+
+ inline
+ void TransposeTokens(std::vector<std::vector<std::string>> const &input,
+ std::vector<std::string>& output,
+ size_t index)
+ {
+ output.clear();
+ for (size_t i = 0; i != input.size(); ++i)
+ {
+ output.emplace_back(input[i][index]);
+ }
+ }
+ } // namespace details
+
+ template<typename DataOptionsType>
+ bool OpenFile(const std::string& filename,
+ DataOptionsType& opts,
+ bool isLoading,
+ std::fstream& stream)
+ {
+ if (isLoading)
+ {
+ #ifdef _WIN32 // Always open in binary mode on Windows.
+ stream.open(filename.c_str(), std::fstream::in
+ | std::fstream::binary);
+ #else
+ stream.open(filename.c_str(), std::fstream::in);
+ #endif
+ }
+ // Add here and else if for ModelOptions in a couple of stages.
+ else
+ {
+ #ifdef _WIN32 // Always open in binary mode on Windows.
+ stream.open(filename.c_str(), std::fstream::out
+ | std::fstream::binary);
+ #else
+ stream.open(filename.c_str(), std::fstream::out);
+ #endif
+ }
+
+ if (!stream.is_open())
+ {
+ if (opts.Fatal() && isLoading)
+ Log::Fatal << "Cannot open file '" << filename << "' for loading. "
+ << "Please check if the file exists." << std::endl;
+
+ else if (!opts.Fatal() && isLoading)
+ Log::Warn << "Cannot open file '" << filename << "' for loading. "
+ << "Please check if the file exists." << std::endl;
+
+ else if (opts.Fatal() && !isLoading)
+ Log::Fatal << "Cannot open file '" << filename << "' for saving. "
+ << "Please check if you have permissions for writing." << std::endl;
+
+ else if (!opts.Fatal() && !isLoading)
+ Log::Warn << "Cannot open file '" << filename << "' for saving. "
+ << "Please check if you have permissions for writing." << std::endl;
+
+ return false;
+ }
+ return true;
+ }
+
+ template<typename MatType, typename DataOptionsType>
+ bool DetectFileType(const std::string& filename,
+ DataOptionsType& opts,
+ bool isLoading,
+ std::fstream* stream = nullptr)
+ {
+ // Add if for ModelOptions in a couple of stages
+ if (opts.Format() == FileType::AutoDetect)
+ {
+ if (isLoading)
+ // Attempt to auto-detect the type from the given file.
+ opts.Format() = AutoDetect(*stream, filename);
+ else
+ DetectFromExtension<MatType>(filename, opts);
+ // Provide error if we don't know the type.
+ if (opts.Format() == FileType::FileTypeUnknown)
+ {
+ if (opts.Fatal())
+ Log::Fatal << "Unable to detect type of '" << filename << "'; "
+ << "Incorrect extension?" << std::endl;
+ else
+ Log::Warn << "Unable to detect type of '" << filename << "'; "
+ << "Incorrect extension?" << std::endl;
+
+ return false;
+ }
+ }
+ return true;
+ }
+
+ template<typename MatType, typename DataOptionsType>
+ bool SaveMatrix(const MatType& matrix,
+ const DataOptionsType& opts,
+ const std::string& filename,
+ std::fstream& stream)
+ {
+ bool success = false;
+ if (opts.Format() == FileType::HDF5Binary)
+ {
+ #ifdef ARMA_USE_HDF5
+ // We can't save with streams for HDF5.
+ success = matrix.save(filename, ToArmaFileType(opts.Format()));
+ #endif
+ }
+ else
+ {
+ success = matrix.save(stream, ToArmaFileType(opts.Format()));
+ }
+ return success;
+ }
+
+ } //namespace data
+ } //namespace mlpack
+
+ #endif
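The helpers in this header (OpenFile(), DetectFileType(), SaveMatrix()) are internal utilities presumably invoked by mlpack's public load/save routines. A minimal usage sketch of that public API, with "dataset.csv" as a hypothetical file name (not taken from the diff):

  #include <mlpack/core.hpp>

  int main()
  {
    arma::mat dataset(5, 20, arma::fill::randu);
    // Saving and loading a matrix; the helpers above handle stream opening,
    // format detection, and the actual Armadillo save call underneath.
    mlpack::data::Save("dataset.csv", dataset);
    arma::mat reloaded;
    mlpack::data::Load("dataset.csv", reloaded);
  }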
@@ -15,6 +15,7 @@
 #define MLPACK_CORE_MATH_CCOV_HPP

 #include <mlpack/prereqs.hpp>
+ #include <mlpack/core/math/make_alias.hpp>

 namespace mlpack {

@@ -31,11 +31,10 @@ inline arma::Mat<eT> ColumnCovariance(const arma::Mat<eT>& x,

 if (x.n_elem > 0)
 {
- const arma::Mat<eT>& xAlias = (x.n_cols == 1) ?
- arma::Mat<eT>(const_cast<eT*>(x.memptr()), x.n_cols, x.n_rows, false,
- false) :
- arma::Mat<eT>(const_cast<eT*>(x.memptr()), x.n_rows, x.n_cols, false,
- false);
+ const arma::Mat<eT> xAlias;
+ MakeAlias(const_cast<arma::Mat<eT>&>(xAlias), x,
+ (x.n_cols == 1) ? x.n_cols : x.n_rows,
+ (x.n_cols == 1) ? x.n_rows : x.n_cols, 0, false);

 const size_t n = xAlias.n_cols;
 const eT normVal = (normType == 0) ? ((n > 1) ? eT(n - 1) : eT(1)) : eT(n);
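For context, a short usage sketch of ColumnCovariance(); its public behaviour is unchanged by this diff, only the internal aliasing now goes through MakeAlias(). The example below is an assumption about typical use, not part of the package:

  #include <mlpack/core.hpp>

  int main()
  {
    // 3-dimensional points, one observation per column.
    arma::mat x(3, 100, arma::fill::randn);
    arma::mat c = mlpack::ColumnCovariance(x);  // 3 x 3 covariance matrix
  }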
@@ -26,7 +26,9 @@ void MakeAlias(OutVecType& v,
 const size_t offset = 0,
 const bool strict = true,
 const typename std::enable_if_t<
- IsVector<OutVecType>::value>* = 0)
+ IsVector<OutVecType>::value &&
+ IsArma<InVecType>::value &&
+ IsArma<OutVecType>::value>* = 0)
 {
 // We use placement new to reinitialize the object, since the copy and move
 // assignment operators in Armadillo will end up copying memory instead of
@@ -49,7 +51,9 @@ void MakeAlias(OutMatType& m,
 const size_t offset = 0,
 const bool strict = true,
 const typename std::enable_if_t<
- IsMatrix<OutMatType>::value>* = 0)
+ IsMatrix<OutMatType>::value &&
+ IsArma<InMatType>::value &&
+ IsArma<OutMatType>::value>* = 0)
 {
 // We use placement new to reinitialize the object, since the copy and move
 // assignment operators in Armadillo will end up copying memory instead of
@@ -72,7 +76,10 @@ void MakeAlias(OutCubeType& c,
 const size_t numSlices,
 const size_t offset = 0,
 const bool strict = true,
- const typename std::enable_if_t<IsCube<OutCubeType>::value>* = 0)
+ const typename std::enable_if_t<
+ IsCube<OutCubeType>::value &&
+ IsArma<InCubeType>::value &&
+ IsArma<OutCubeType>::value>* = 0)
 {
 // We use placement new to reinitialize the object, since the copy and move
 // assignment operators in Armadillo will end up copying memory instead of
@@ -119,6 +126,94 @@ void ClearAlias(arma::SpMat<ElemType>& /* mat */)
 // We cannot make aliases of sparse matrices, so, nothing to do.
 }

+ #if defined(MLPACK_HAS_COOT)
+
+ /**
+ * Reconstruct `v` as an alias around the memory `newMem`, with size `numRows` x
+ * `numCols`.
+ */
+ template<typename InVecType, typename OutVecType>
+ void MakeAlias(OutVecType& v,
+ const InVecType& oldVec,
+ const size_t numElems,
+ const size_t offset = 0,
+ const bool strict = true,
+ const typename std::enable_if_t<
+ IsVector<OutVecType>::value &&
+ IsCoot<InVecType>::value &&
+ IsCoot<OutVecType>::value>* = 0)
+ {
+ // We use placement new to reinitialize the object, since the copy and move
+ // assignment operators in Bandicoot will end up copying memory instead of
+ // making an alias.
+ coot::dev_mem_t<InVecType::elem_type> newMem = oldVec.get_dev_mem() + offset;
+ v.~OutVecType();
+ new (&v) OutVecType(newMem, numElems, false, strict);
+ }
+
+ /**
+ * Reconstruct `m` as an alias around the memory `newMem`, with size `numRows` x
+ * `numCols`.
+ */
+ template<typename InMatType, typename OutMatType>
+ void MakeAlias(OutMatType& m,
+ const InMatType& oldMat,
+ const size_t numRows,
+ const size_t numCols,
+ const size_t offset = 0,
+ const bool strict = true,
+ const typename std::enable_if_t<
+ IsMatrix<OutMatType>::value &&
+ IsCoot<InMatType>::value &&
+ IsCoot<OutMatType>::value>* = 0)
+ {
+ // We use placement new to reinitialize the object, since the copy and move
+ // assignment operators in Bandicoot will end up copying memory instead of
+ // making an alias.
+ coot::dev_mem_t<InMatType::elem_type> newMem = oldMat.get_dev_mem() + offset;
+ m.~OutMatType();
+ new (&m) OutMatType(newMem, numRows, numCols);
+ }
+
+ /**
+ * Reconstruct `c` as an alias around the memory` newMem`, with size `numRows` x
+ * `numCols` x `numSlices`.
+ */
+ template<typename InCubeType, typename OutCubeType>
+ void MakeAlias(OutCubeType& c,
+ const InCubeType& oldCube,
+ const size_t numRows,
+ const size_t numCols,
+ const size_t numSlices,
+ const size_t offset = 0,
+ const bool strict = true,
+ const typename std::enable_if_t<
+ IsCube<OutCubeType>::value &&
+ IsCoot<InCubeType>::value &&
+ IsCoot<OutCubeType>::value>* = 0)
+ {
+ // We use placement new to reinitialize the object, since the copy and move
+ // assignment operators in Bandicoot will end up copying memory instead of
+ // making an alias.
+ coot::dev_mem_t<InCubeType::elem_type> newMem =
+ oldCube.get_dev_mem() + offset;
+ c.~OutCubeType();
+ new (&c) OutCubeType(newMem, numRows, numCols, numSlices);
+ }
+
+ /**
+ * Clear an alias so that no data is overwritten. This resets the matrix if it
+ * is an alias (and does nothing otherwise).
+ */
+ template<typename ElemType>
+ void ClearAlias(coot::Mat<ElemType>& mat)
+ {
+ if (mat.mem_state >= 1)
+ mat.reset();
+ }
+
+ #endif // defined(MLPACK_HAS_COOT)
+
 } // namespace mlpack

 #endif
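The new Bandicoot overloads mirror the existing Armadillo ones, so calling code looks the same regardless of backend. A minimal sketch using the Armadillo matrix overload whose signature appears above (assumes mlpack headers are available; not taken from the diff):

  #include <mlpack/core.hpp>

  int main()
  {
    arma::mat params(10, 10, arma::fill::randu);
    arma::mat block;
    // Rebuild `block` as a non-owning 5x5 alias over the first 25 elements of
    // `params`; writing through the alias modifies `params` directly.
    mlpack::MakeAlias(block, params, 5, 5, /* offset */ 0, /* strict */ false);
    block.zeros();
  }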
@@ -239,6 +239,16 @@ struct GetCubeType<arma::Mat<eT>>
 using type = arma::Cube<eT>;
 };

+ #if defined(MLPACK_HAS_COOT)
+
+ template<typename eT>
+ struct GetCubeType<coot::Mat<eT>>
+ {
+ using type = coot::Cube<eT>;
+ };
+
+ #endif
+
 // Get the sparse matrix type corresponding to a given MatType.

 template<typename MatType>
@@ -346,18 +356,25 @@ struct IsDense<arma::Mat<eT>>
 constexpr static bool value = true;
 };

+ // Get whether or not the given type is any non-field Armadillo type
+ // This includes sparse, dense, and cube types
 template<typename T>
 struct IsArma
 {
- constexpr static bool value = arma::is_arma_type<T>::value;
+ constexpr static bool value = arma::is_arma_type<T>::value ||
+ arma::is_arma_cube_type<T>::value ||
+ arma::is_arma_sparse_type<T>::value;
 };

 #if defined(MLPACK_HAS_COOT)

+ // Get whether or not the given type is any Bandicoot type
+ // This includes dense and cube types
 template<typename T>
 struct IsCoot
 {
- constexpr static bool value = coot::is_coot_type<T>::value;
+ constexpr static bool value = coot::is_coot_type<T>::value ||
+ coot::is_coot_cube_type<T>::value;
 };

 #else
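A compile-time sketch of what the broadened IsArma trait now accepts (assuming, as for the other traits in this header, that it lives in the mlpack namespace):

  #include <mlpack/core.hpp>

  static_assert(mlpack::IsArma<arma::mat>::value, "dense types");
  static_assert(mlpack::IsArma<arma::cube>::value, "cube types are now included");
  static_assert(mlpack::IsArma<arma::sp_mat>::value, "sparse types are now included");

  int main() { }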
@@ -1 +1 @@
- return "mlpack git-0480311ab1";
+ return "mlpack git-0fdccbfb21";
@@ -98,6 +98,28 @@ struct MethodFormDetector<Class, MethodForm, 7>
 void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7>);
 };

+ template<typename Class, template<typename...> class MethodForm>
+ struct MethodFormDetector<Class, MethodForm, 8>
+ {
+ template<class T1, class T2, class T3, class T4, class T5, class T6, class T7,
+ class T8>
+ void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7, T8>);
+ };
+ template<typename Class, template<typename...> class MethodForm>
+ struct MethodFormDetector<Class, MethodForm, 9>
+ {
+ template<class T1, class T2, class T3, class T4, class T5, class T6, class T7,
+ class T8, class T9>
+ void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7, T8, T9>);
+ };
+ template<typename Class, template<typename...> class MethodForm>
+ struct MethodFormDetector<Class, MethodForm, 10>
+ {
+ template<class T1, class T2, class T3, class T4, class T5, class T6, class T7,
+ class T8, class T9, class T10>
+ void operator()(MethodForm<Class, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>);
+ };
+
 //! Utility struct for checking signatures.
 template<typename U, U> struct SigCheck : std::true_type {};

@@ -237,7 +259,7 @@ struct NAME \
 * we can check whether the class A has a Train method of the specified form:
 *
 * HAS_METHOD_FORM(Train, HasTrain);
- * static_assert(HasTrain<A, TrainFrom>::value, "value should be true");
+ * static_assert(HasTrain<A, TrainForm>::value, "value should be true");
 *
 * The class generated by this will also return true values if the given class
 * has a method that also has extra parameters.
@@ -246,7 +268,7 @@ struct NAME \
 * @param NAME The name of the struct to construct.
 */
 #define HAS_METHOD_FORM(METHOD, NAME) \
- HAS_METHOD_FORM_BASE(SINGLE_ARG(METHOD), SINGLE_ARG(NAME), 7)
+ HAS_METHOD_FORM_BASE(SINGLE_ARG(METHOD), SINGLE_ARG(NAME), 10)

 /**
 * HAS_EXACT_METHOD_FORM generates a template that allows a compile time check
@@ -18,7 +18,7 @@
 // with higher number than the most recent release.
 #define MLPACK_VERSION_MAJOR 4
 #define MLPACK_VERSION_MINOR 6
- #define MLPACK_VERSION_PATCH 1
+ #define MLPACK_VERSION_PATCH 2

 // The name of the version (for use by --version).
 namespace mlpack {
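A quick way to confirm the bumped version macros from C++ (sketch, not part of the diff):

  #include <mlpack/core.hpp>
  #include <iostream>

  int main()
  {
    // Prints 4.6.2 for the headers shipped in this wheel.
    std::cout << MLPACK_VERSION_MAJOR << "." << MLPACK_VERSION_MINOR << "."
              << MLPACK_VERSION_PATCH << std::endl;
  }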
@@ -40,8 +40,7 @@ BernoulliDistribution<DataType>::BernoulliDistribution(
 }
 else
 {
- probability = arma::mat(logits.memptr(), logits.n_rows,
- logits.n_cols, false, false);
+ MakeAlias(probability, logits, logits.n_rows, logits.n_cols);
 }
 }

@@ -48,9 +48,9 @@ class NetworkInitialization
 * @param parameter The network parameter.
 * @param parameterOffset Offset for network paramater, default 0.
 */
- template <typename eT>
- void Initialize(const std::vector<Layer<arma::Mat<eT>>*>& network,
- arma::Mat<eT>& parameters,
+ template <typename MatType>
+ void Initialize(const std::vector<Layer<MatType>*>& network,
+ MatType& parameters,
 size_t parameterOffset = 0)
 {
 // Determine the total number of parameters/weights of the given network.
@@ -71,8 +71,8 @@ class NetworkInitialization
 // Initialize the layer with the specified parameter/weight
 // initialization rule.
 const size_t weight = network[i]->WeightSize();
- arma::Mat<eT> tmp = arma::Mat<eT>(parameters.memptr() + offset,
- weight, 1, false, false);
+ MatType tmp;
+ MakeAlias(tmp, parameters, weight, 1, offset, false);
 initializeRule.Initialize(tmp, tmp.n_elem, 1);

 // Increase the parameter/weight offset for the next layer.
@@ -53,6 +53,7 @@ template <typename MatType = arma::mat>
 class BatchNormType : public Layer<MatType>
 {
 public:
+ using CubeType = typename GetCubeType<MatType>::type;
 /**
 * Create the BatchNorm object.
 *
@@ -266,10 +267,10 @@ class BatchNormType : public Layer<MatType>
 MatType runningVariance;

 //! Locally-stored normalized input.
- arma::Cube<typename MatType::elem_type> normalized;
+ CubeType normalized;

 //! Locally-stored zero mean input.
- arma::Cube<typename MatType::elem_type> inputMean;
+ CubeType inputMean;
 }; // class BatchNorm

 // Convenience typedefs.
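The new CubeType member alias is built from the GetCubeType trait extended earlier in this diff. A small sketch of what it resolves to for the default matrix type (assuming the trait sits in the mlpack namespace alongside the other traits):

  #include <mlpack/core.hpp>
  #include <type_traits>

  static_assert(std::is_same<mlpack::GetCubeType<arma::mat>::type,
                             arma::cube>::value,
                "arma::mat maps to arma::cube");

  int main() { }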