mlpack 4.6.0__cp310-cp310-win_amd64.whl → 4.6.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp310-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp310-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp310-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp310-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp310-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp310-win_amd64.pyd +0 -0
  8. mlpack/cf.cp310-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp310-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp310-win_amd64.pyd +0 -0
  11. mlpack/det.cp310-win_amd64.pyd +0 -0
  12. mlpack/emst.cp310-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp310-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp310-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp310-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp310-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp310-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp310-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp310-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp310-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp310-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp310-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/data/data.hpp +5 -1
  25. mlpack/include/mlpack/core/data/data_options.hpp +219 -0
  26. mlpack/include/mlpack/core/data/detect_file_type.hpp +6 -8
  27. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +21 -30
  28. mlpack/include/mlpack/core/data/load.hpp +41 -3
  29. mlpack/include/mlpack/core/data/load_arff.hpp +4 -3
  30. mlpack/include/mlpack/core/data/load_arff_impl.hpp +68 -20
  31. mlpack/include/mlpack/core/data/{load_csv.hpp → load_categorical.hpp} +44 -80
  32. mlpack/include/mlpack/core/data/{load_categorical_csv.hpp → load_categorical_impl.hpp} +86 -46
  33. mlpack/include/mlpack/core/data/load_impl.hpp +264 -289
  34. mlpack/include/mlpack/core/data/load_model_impl.hpp +2 -1
  35. mlpack/include/mlpack/core/data/load_numeric.hpp +130 -0
  36. mlpack/include/mlpack/core/data/load_vec_impl.hpp +14 -10
  37. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +3 -2
  38. mlpack/include/mlpack/core/data/matrix_options.hpp +172 -0
  39. mlpack/include/mlpack/core/data/save.hpp +32 -2
  40. mlpack/include/mlpack/core/data/save_impl.hpp +315 -346
  41. mlpack/include/mlpack/core/data/text_options.hpp +244 -0
  42. mlpack/include/mlpack/core/data/types.hpp +3 -4
  43. mlpack/include/mlpack/core/data/utilities.hpp +158 -0
  44. mlpack/include/mlpack/core/math/shuffle_data.hpp +68 -0
  45. mlpack/include/mlpack/core/metrics/bleu_impl.hpp +1 -1
  46. mlpack/include/mlpack/core/tree/binary_space_tree/traits.hpp +36 -178
  47. mlpack/include/mlpack/core/tree/space_split/hyperplane.hpp +20 -14
  48. mlpack/include/mlpack/core/tree/space_split/mean_space_split_impl.hpp +2 -2
  49. mlpack/include/mlpack/core/tree/space_split/midpoint_space_split_impl.hpp +1 -1
  50. mlpack/include/mlpack/core/tree/space_split/projection_vector.hpp +6 -5
  51. mlpack/include/mlpack/core/tree/space_split/space_split.hpp +4 -4
  52. mlpack/include/mlpack/core/tree/space_split/space_split_impl.hpp +18 -12
  53. mlpack/include/mlpack/core/tree/spill_tree/is_spill_tree.hpp +1 -1
  54. mlpack/include/mlpack/core/tree/spill_tree/spill_dual_tree_traverser.hpp +2 -1
  55. mlpack/include/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp +4 -2
  56. mlpack/include/mlpack/core/tree/spill_tree/spill_single_tree_traverser.hpp +2 -1
  57. mlpack/include/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp +4 -2
  58. mlpack/include/mlpack/core/tree/spill_tree/spill_tree.hpp +13 -16
  59. mlpack/include/mlpack/core/tree/spill_tree/spill_tree_impl.hpp +78 -51
  60. mlpack/include/mlpack/core/tree/spill_tree/traits.hpp +2 -1
  61. mlpack/include/mlpack/core/tree/spill_tree/typedef.hpp +12 -4
  62. mlpack/include/mlpack/core/util/arma_traits.hpp +48 -0
  63. mlpack/include/mlpack/core/util/gitversion.hpp +1 -1
  64. mlpack/include/mlpack/core/util/version.hpp +1 -1
  65. mlpack/include/mlpack/methods/CMakeLists.txt +96 -96
  66. mlpack/include/mlpack/methods/amf/init_rules/no_init.hpp +1 -1
  67. mlpack/include/mlpack/methods/amf/update_rules/svd_batch_learning.hpp +0 -2
  68. mlpack/include/mlpack/methods/ann/loss_functions/empty_loss.hpp +1 -1
  69. mlpack/include/mlpack/methods/ann/loss_functions/mean_absolute_percentage_error.hpp +1 -1
  70. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +9 -1
  71. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +1 -1
  72. mlpack/include/mlpack/methods/lmnn/lmnn_impl.hpp +1 -1
  73. mlpack/kde.cp310-win_amd64.pyd +0 -0
  74. mlpack/kernel_pca.cp310-win_amd64.pyd +0 -0
  75. mlpack/kfn.cp310-win_amd64.pyd +0 -0
  76. mlpack/kmeans.cp310-win_amd64.pyd +0 -0
  77. mlpack/knn.cp310-win_amd64.pyd +0 -0
  78. mlpack/krann.cp310-win_amd64.pyd +0 -0
  79. mlpack/lars.cp310-win_amd64.pyd +0 -0
  80. mlpack/linear_regression_predict.cp310-win_amd64.pyd +0 -0
  81. mlpack/linear_regression_train.cp310-win_amd64.pyd +0 -0
  82. mlpack/linear_svm.cp310-win_amd64.pyd +0 -0
  83. mlpack/lmnn.cp310-win_amd64.pyd +0 -0
  84. mlpack/local_coordinate_coding.cp310-win_amd64.pyd +0 -0
  85. mlpack/logistic_regression.cp310-win_amd64.pyd +0 -0
  86. mlpack/lsh.cp310-win_amd64.pyd +0 -0
  87. mlpack/mean_shift.cp310-win_amd64.pyd +0 -0
  88. mlpack/nbc.cp310-win_amd64.pyd +0 -0
  89. mlpack/nca.cp310-win_amd64.pyd +0 -0
  90. mlpack/nmf.cp310-win_amd64.pyd +0 -0
  91. mlpack/pca.cp310-win_amd64.pyd +0 -0
  92. mlpack/perceptron.cp310-win_amd64.pyd +0 -0
  93. mlpack/preprocess_binarize.cp310-win_amd64.pyd +0 -0
  94. mlpack/preprocess_describe.cp310-win_amd64.pyd +0 -0
  95. mlpack/preprocess_one_hot_encoding.cp310-win_amd64.pyd +0 -0
  96. mlpack/preprocess_scale.cp310-win_amd64.pyd +0 -0
  97. mlpack/preprocess_split.cp310-win_amd64.pyd +0 -0
  98. mlpack/radical.cp310-win_amd64.pyd +0 -0
  99. mlpack/random_forest.cp310-win_amd64.pyd +0 -0
  100. mlpack/softmax_regression.cp310-win_amd64.pyd +0 -0
  101. mlpack/sparse_coding.cp310-win_amd64.pyd +0 -0
  102. mlpack-4.6.1.dist-info/DELVEWHEEL +2 -0
  103. {mlpack-4.6.0.dist-info → mlpack-4.6.1.dist-info}/METADATA +6 -2
  104. {mlpack-4.6.0.dist-info → mlpack-4.6.1.dist-info}/RECORD +106 -102
  105. {mlpack-4.6.0.dist-info → mlpack-4.6.1.dist-info}/WHEEL +1 -1
  106. mlpack/include/mlpack/core/data/load_numeric_csv.hpp +0 -192
  107. mlpack-4.6.0.dist-info/DELVEWHEEL +0 -2
  108. {mlpack-4.6.0.dist-info → mlpack-4.6.1.dist-info}/top_level.txt +0 -0
mlpack/__init__.py CHANGED
@@ -11,14 +11,14 @@ http://www.opensource.org/licenses/BSD-3-Clause for more information.
11
11
 
12
12
 
13
13
  # start delvewheel patch
14
- def _delvewheel_patch_1_10_0():
14
+ def _delvewheel_patch_1_10_1():
15
15
  import os
16
16
  if os.path.isdir(libs_dir := os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'mlpack.libs'))):
17
17
  os.add_dll_directory(libs_dir)
18
18
 
19
19
 
20
- _delvewheel_patch_1_10_0()
21
- del _delvewheel_patch_1_10_0
20
+ _delvewheel_patch_1_10_1()
21
+ del _delvewheel_patch_1_10_1
22
22
  # end delvewheel patch
23
23
 
24
24
  import warnings
@@ -74,4 +74,4 @@ from .adaboost import *
74
74
  from .linear_regression_train import linear_regression_train
75
75
  from .linear_regression_predict import linear_regression_predict
76
76
  from .linear_regression import *
77
- __version__='4.6.0'
77
+ __version__='4.6.1'
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
@@ -33,6 +33,7 @@
33
33
  #include <numeric>
34
34
  #include <vector>
35
35
  #include <queue>
36
+ #include <string>
36
37
 
37
38
  // But if it's not defined, we'll do it.
38
39
  #ifndef M_PI
@@ -12,7 +12,6 @@
12
12
  #ifndef MLPACK_CORE_DATA_DATA_HPP
13
13
  #define MLPACK_CORE_DATA_DATA_HPP
14
14
 
15
- #include "detect_file_type.hpp"
16
15
  #include "extension.hpp"
17
16
  #include "format.hpp"
18
17
  #include "has_serialize.hpp"
@@ -30,14 +29,19 @@
30
29
  #include "check_categorical_param.hpp"
31
30
  #include "confusion_matrix.hpp"
32
31
  #include "dataset_mapper.hpp"
32
+ #include "data_options.hpp"
33
+ #include "detect_file_type.hpp"
33
34
  #include "image_info.hpp"
34
35
  #include "image_resize_crop.hpp"
35
36
  #include "imputer.hpp"
36
37
  #include "is_naninf.hpp"
38
+ #include "matrix_options.hpp"
37
39
  #include "normalize_labels.hpp"
38
40
  #include "one_hot_encoding.hpp"
39
41
  #include "split_data.hpp"
40
42
  #include "string_algorithms.hpp"
43
+ #include "text_options.hpp"
41
44
  #include "types.hpp"
45
+ #include "utilities.hpp"
42
46
 
43
47
  #endif
@@ -0,0 +1,219 @@
1
+ /**
2
+ * @file core/data/data_options.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Data options, all possible options to load different data types and format
7
+ * with specific settings into mlpack.
8
+ *
9
+ * mlpack is free software; you may redistribute it and/or modify it under the
10
+ * terms of the 3-clause BSD license. You should have received a copy of the
11
+ * 3-clause BSD license along with mlpack. If not, see
12
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13
+ */
14
+ #ifndef MLPACK_CORE_DATA_DATA_OPTIONS_HPP
15
+ #define MLPACK_CORE_DATA_DATA_OPTIONS_HPP
16
+
17
+ #include <mlpack/prereqs.hpp>
18
+
19
+ #include "types.hpp"
20
+ #include "dataset_mapper.hpp"
21
+ #include "map_policies/map_policies.hpp"
22
+ #include "format.hpp"
23
+ #include "image_info.hpp"
24
+
25
+ namespace mlpack {
26
+ namespace data {
27
+
28
+ /**
29
+ * All possible DataOptions grouped under one class.
30
+ * This will allow us to have consistent data API for mlpack. If new data
31
+ * options might be necessary, then they should be added in the following.
32
+ */
33
+
34
+ template<typename Derived>
35
+ class DataOptionsBase
36
+ {
37
+ public:
38
+ DataOptionsBase(const bool fatal = defaultFatal,
39
+ const FileType format = defaultFormat) :
40
+ fatal(fatal),
41
+ format(format)
42
+ {
43
+ // Do nothing.
44
+ }
45
+
46
+ template<typename Derived2>
47
+ explicit DataOptionsBase(const DataOptionsBase<Derived2>& opts)
48
+ {
49
+ CopyOptions(opts);
50
+ }
51
+
52
+ template<typename Derived2>
53
+ explicit DataOptionsBase(DataOptionsBase<Derived2>&& opts)
54
+ {
55
+ MoveOptions(std::move(opts));
56
+ }
57
+
58
+ // Convert any other DataOptions type to this DataOptions type, printing
59
+ // warnings for any members that cannot be converted. If this object and
60
+ // `opts` are of the same type, then the constructor for that type will be
61
+ // called instead.
62
+ template<typename Derived2>
63
+ DataOptionsBase& operator=(const DataOptionsBase<Derived2>& other)
64
+ {
65
+ if ((void*) &other == (void*) this)
66
+ return *this;
67
+
68
+ // Print warnings for any members that cannot be converted.
69
+ const char* dataDesc = static_cast<const Derived&>(*this).DataDescription();
70
+ static_cast<const Derived2&>(other).WarnBaseConversion(dataDesc);
71
+
72
+ CopyOptions(other);
73
+ return *this;
74
+ }
75
+
76
+ // Take ownership of the options of another `DataOptionsBase` type.
77
+ template<typename Derived2>
78
+ DataOptionsBase& operator=(DataOptionsBase<Derived2>&& other)
79
+ {
80
+ if ((void*) &other != (void*) this)
81
+ return *this;
82
+
83
+ // Print warnings for any members that cannot be converted.
84
+ const char* dataDesc = static_cast<const Derived&>(*this).DataDescription();
85
+ static_cast<const Derived2&>(other).WarnBaseConversion(dataDesc);
86
+
87
+ MoveOptions(std::move(other));
88
+ return *this;
89
+ }
90
+
91
+ template<typename Derived2>
92
+ void CopyOptions(const DataOptionsBase<Derived2>& other)
93
+ {
94
+ // Only copy options that have been set in the other object.
95
+ if (other.fatal.has_value())
96
+ fatal = *other.fatal;
97
+ if (other.format.has_value())
98
+ format = *other.format;
99
+ }
100
+
101
+ template<typename Derived2>
102
+ void MoveOptions(DataOptionsBase<Derived2>&& other)
103
+ {
104
+ fatal = std::move(other.fatal);
105
+ format = std::move(other.format);
106
+
107
+ // Reset all of the options in the other object.
108
+ other.Reset();
109
+ }
110
+
111
+ void Reset()
112
+ {
113
+ fatal.reset();
114
+ format.reset();
115
+
116
+ // Reset any child members.
117
+ static_cast<Derived&>(*this).Reset();
118
+ }
119
+
120
+ // If true, then exceptions are thrown on failures.
121
+ const bool& Fatal() const { return AccessMember(fatal, defaultFatal); }
122
+ // Modify whether or not exceptions are thrown on failures.
123
+ bool& Fatal() { return ModifyMember(fatal, defaultFatal); }
124
+
125
+ // Get the type of the file that will be loaded.
126
+ const FileType& Format() const { return AccessMember(format, defaultFormat); }
127
+ // Modify the file format to load.
128
+ FileType& Format() { return ModifyMember(format, defaultFormat); }
129
+
130
+ /**
131
+ * Given a file type, return a logical name corresponding to that file type.
132
+ */
133
+ const std::string FileTypeToString() const
134
+ {
135
+ FileType f = format.has_value() ? *format : defaultFormat;
136
+ switch (f)
137
+ {
138
+ case FileType::CSVASCII: return "CSV data";
139
+ case FileType::RawASCII: return "raw ASCII formatted data";
140
+ case FileType::RawBinary: return "raw binary formatted data";
141
+ case FileType::ArmaASCII: return "Armadillo ASCII formatted data";
142
+ case FileType::ArmaBinary: return "Armadillo binary formatted data";
143
+ case FileType::PGMBinary: return "PGM data";
144
+ case FileType::PPMBinary: return "PGM data";
145
+ case FileType::HDF5Binary: return "HDF5 data";
146
+ case FileType::CoordASCII:
147
+ return "ASCII formatted sparse coordinate data";
148
+ case FileType::AutoDetect: return "Detect automatically data type";
149
+ case FileType::FileTypeUnknown: return "Unknown data type";
150
+ default: return "";
151
+ }
152
+ }
153
+
154
+ protected:
155
+ template<typename T>
156
+ const T& AccessMember(const std::optional<T>& value,
157
+ const T& defaultValue) const
158
+ {
159
+ if (value.has_value())
160
+ return *value;
161
+ else
162
+ return defaultValue;
163
+ }
164
+
165
+ template<typename T>
166
+ T& ModifyMember(std::optional<T>& value, const T defaultValue)
167
+ {
168
+ // Set the default value if needed so that (*value) has defined behavior
169
+ // according to the spec.
170
+ if (!value.has_value())
171
+ value = defaultValue;
172
+
173
+ return *value;
174
+ }
175
+
176
+ void WarnOptionConversion(const char* optionName, const char* dataType) const
177
+ {
178
+ if (fatal.has_value() && *fatal)
179
+ {
180
+ Log::Fatal << "Option '" << optionName << "' cannot be specified when "
181
+ << dataType << " is being loaded!" << std::endl;
182
+ }
183
+ else
184
+ {
185
+ Log::Warn << "Option '" << optionName << "' ignored; not applicable when "
186
+ << dataType << " is being loaded!" << std::endl;
187
+ }
188
+ }
189
+
190
+ private:
191
+ std::optional<bool> fatal;
192
+ std::optional<FileType> format;
193
+
194
+ constexpr static const bool defaultFatal = false;
195
+ constexpr static const FileType defaultFormat = FileType::AutoDetect;
196
+
197
+ // For access to internal optional members.
198
+ template<typename Derived2>
199
+ friend class DataOptionsBase;
200
+ };
201
+
202
+ // This utility class is meant to be used as the Derived parameter for an option
203
+ // that is not actually a derived type. It provides the WarnBaseConversion()
204
+ // member, which does nothing.
205
+ class EmptyOptions : public DataOptionsBase<EmptyOptions>
206
+ {
207
+ public:
208
+ void WarnBaseConversion(const char* /* dataDescription */) const { }
209
+ static const char* DataDescription() { return "general data"; }
210
+ void Reset() { }
211
+ };
212
+
213
+ using DataOptions = DataOptionsBase<EmptyOptions>;
214
+
215
+
216
+ } // namespace data
217
+ } // namespace mlpack
218
+
219
+ #endif
@@ -16,17 +16,13 @@
16
16
  #define MLPACK_CORE_DATA_DETECT_FILE_TYPE_HPP
17
17
 
18
18
  #include "types.hpp"
19
+ #include "extension.hpp"
20
+ #include "string_algorithms.hpp"
21
+ #include "text_options.hpp"
19
22
 
20
23
  namespace mlpack {
21
24
  namespace data {
22
25
 
23
- /**
24
- * Given a file type, return a logical name corresponding to that file type.
25
- *
26
- * @param type Type to get the logical name of.
27
- */
28
- inline std::string GetStringType(const FileType& type);
29
-
30
26
  /**
31
27
  * Given an istream, attempt to guess the file type. This is taken originally
32
28
  * from Armadillo's function guess_file_type_internal(), but we avoid using
@@ -62,7 +58,9 @@ inline FileType AutoDetect(std::fstream& stream,
62
58
  * @param filename Name of the file whose type we should detect.
63
59
  * @param opts Receives the detected type via Format(); FileTypeUnknown if unknown.
64
60
  */
65
- inline FileType DetectFromExtension(const std::string& filename);
61
+ template<typename MatType, typename DataOptionsType>
62
+ void DetectFromExtension(const std::string& filename,
63
+ DataOptionsType& opts);
66
64
 
67
65
  /**
68
66
  * Count the number of columns in the file. The file must be a CSV/TSV/TXT file
@@ -12,34 +12,11 @@
12
12
  * 3-clause BSD license along with mlpack. If not, see
13
13
  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
14
14
  */
15
- #include "extension.hpp"
16
15
  #include "detect_file_type.hpp"
17
- #include "string_algorithms.hpp"
18
16
 
19
17
  namespace mlpack {
20
18
  namespace data {
21
19
 
22
- /**
23
- * Given a file type, return a logical name corresponding to that file type.
24
- *
25
- * @param type Type to get the logical name of.
26
- */
27
- inline std::string GetStringType(const FileType& type)
28
- {
29
- switch (type)
30
- {
31
- case FileType::CSVASCII: return "CSV data";
32
- case FileType::RawASCII: return "raw ASCII formatted data";
33
- case FileType::RawBinary: return "raw binary formatted data";
34
- case FileType::ArmaASCII: return "Armadillo ASCII formatted data";
35
- case FileType::ArmaBinary: return "Armadillo binary formatted data";
36
- case FileType::PGMBinary: return "PGM data";
37
- case FileType::HDF5Binary: return "HDF5 data";
38
- case FileType::CoordASCII: return "ASCII formatted sparse coordinate data";
39
- default: return "";
40
- }
41
- }
42
-
43
20
  /**
44
21
  * Given an istream, attempt to guess the file type. This is taken originally
45
22
  * from Armadillo's function guess_file_type_internal(), but we avoid using
@@ -292,6 +269,11 @@ inline FileType AutoDetect(std::fstream& stream, const std::string& filename)
292
269
  {
293
270
  detectedLoadType = FileType::HDF5Binary;
294
271
  }
272
+ else if (extension == "arff")
273
+ {
274
+ return FileType::ARFFASCII;
275
+ }
276
+
295
277
  else // Unknown extension...
296
278
  {
297
279
  detectedLoadType = FileType::FileTypeUnknown;
@@ -306,34 +288,43 @@ inline FileType AutoDetect(std::fstream& stream, const std::string& filename)
306
288
  * @param filename Name of the file whose type we should detect.
307
289
  * @param opts Receives the detected type via Format(); FileTypeUnknown if unknown.
308
290
  */
309
- inline FileType DetectFromExtension(const std::string& filename)
291
+ template<typename MatType, typename DataOptionsType>
292
+ void DetectFromExtension(const std::string& filename,
293
+ DataOptionsType& opts)
310
294
  {
311
295
  const std::string extension = Extension(filename);
312
296
 
313
297
  if (extension == "csv")
314
298
  {
315
- return FileType::CSVASCII;
299
+ opts.Format() = FileType::CSVASCII;
316
300
  }
317
301
  else if (extension == "txt")
318
302
  {
319
- return FileType::RawASCII;
303
+ if (IsSparseMat<MatType>::value)
304
+ opts.Format() = FileType::CoordASCII;
305
+ else
306
+ opts.Format() = FileType::RawASCII;
320
307
  }
321
308
  else if (extension == "bin")
322
309
  {
323
- return FileType::ArmaBinary;
310
+ opts.Format() = FileType::ArmaBinary;
324
311
  }
325
312
  else if (extension == "pgm")
326
313
  {
327
- return FileType::PGMBinary;
314
+ opts.Format() = FileType::PGMBinary;
328
315
  }
329
316
  else if (extension == "h5" || extension == "hdf5" || extension == "hdf" ||
330
317
  extension == "he5")
331
318
  {
332
- return FileType::HDF5Binary;
319
+ opts.Format() = FileType::HDF5Binary;
320
+ }
321
+ else if (extension == "arff")
322
+ {
323
+ opts.Format() = FileType::ARFFASCII;
333
324
  }
334
325
  else
335
326
  {
336
- return FileType::FileTypeUnknown;
327
+ opts.Format() = FileType::FileTypeUnknown;
337
328
  }
338
329
  }
339
330
 
@@ -1,6 +1,7 @@
1
1
  /**
2
2
  * @file core/data/load.hpp
3
3
  * @author Ryan Curtin
4
+ * @author Omar Shrit
4
5
  *
5
6
  * Load an Armadillo matrix from file. This is necessary because Armadillo does
6
7
  * not transpose matrices on input, and it allows us to give better error
@@ -15,19 +16,55 @@
15
16
  #define MLPACK_CORE_DATA_LOAD_HPP
16
17
 
17
18
  #include <mlpack/prereqs.hpp>
18
- #include <string>
19
19
 
20
+ #include "text_options.hpp"
20
21
  #include "format.hpp"
21
22
  #include "dataset_mapper.hpp"
22
23
  #include "detect_file_type.hpp"
23
24
  #include "image_info.hpp"
24
- #include "load_csv.hpp"
25
25
  #include "load_arff.hpp"
26
+ #include "load_numeric.hpp"
27
+ #include "load_categorical.hpp"
26
28
  #include "load_image.hpp"
29
+ #include "utilities.hpp"
27
30
 
28
31
  namespace mlpack {
29
32
  namespace data /** Functions to load and save matrices and models. */ {
30
33
 
34
+ /**
35
+ * Loads a matrix from file, guessing the filetype from the extension. This
36
+ * will load with the options specified in `opts`.
37
+ *
38
+ * @param filename Name of file to load.
39
+ * @param matrix Matrix to load contents of file into.
40
+ * @param opts DataOptions to be passed to the function
41
+ * @return Boolean value indicating success or failure of load.
42
+ */
43
+ template<typename MatType, typename DataOptionsType>
44
+ bool Load(const std::string& filename,
45
+ MatType& matrix,
46
+ DataOptionsType& opts,
47
+ std::enable_if_t<IsArma<MatType>::value ||
48
+ IsSparseMat<MatType>::value>* = 0,
49
+ std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>* = 0);
50
+
51
+ /**
52
+ * Loads a matrix from file, guessing the filetype from the extension. This
53
+ * will load with the options specified in `opts`.
54
+ *
55
+ * @param filename Name of file to load.
56
+ * @param matrix Matrix to load contents of file into.
57
+ * @param opts Non-modifiable DataOptions to be passed to the function
58
+ * @return Boolean value indicating success or failure of load.
59
+ */
60
+ template<typename MatType, typename DataOptionsType>
61
+ bool Load(const std::string& filename,
62
+ MatType& matrix,
63
+ const DataOptionsType& opts,
64
+ std::enable_if_t<IsArma<MatType>::value ||
65
+ IsSparseMat<MatType>::value>* = 0,
66
+ std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>* = 0);
67
+
31
68
  /**
32
69
  * Loads a matrix from file, guessing the filetype from the extension. This
33
70
  * will transpose the matrix at load time (unless the transpose parameter is set
@@ -250,7 +287,8 @@ bool Load(const std::string& filename,
250
287
  const std::string& name,
251
288
  T& t,
252
289
  const bool fatal = false,
253
- format f = format::autodetect);
290
+ format f = format::autodetect,
291
+ std::enable_if_t<HasSerialize<T>::value>* = 0);
254
292
 
255
293
  } // namespace data
256
294
  } // namespace mlpack
@@ -25,7 +25,7 @@ namespace data {
25
25
  * if any features are non-numeric.
26
26
  */
27
27
  template<typename eT>
28
- void LoadARFF(const std::string& filename, arma::Mat<eT>& matrix);
28
+ bool LoadARFF(const std::string& filename, arma::Mat<eT>& matrix);
29
29
 
30
30
  /**
31
31
  * A utility function to load an ARFF dataset as numeric and categorical
@@ -50,9 +50,10 @@ void LoadARFF(const std::string& filename, arma::Mat<eT>& matrix);
50
50
  * from another call to LoadARFF().
51
51
  */
52
52
  template<typename eT, typename PolicyType>
53
- void LoadARFF(const std::string& filename,
53
+ bool LoadARFF(const std::string& filename,
54
54
  arma::Mat<eT>& matrix,
55
- DatasetMapper<PolicyType>& info);
55
+ DatasetMapper<PolicyType>& info,
56
+ bool fatal);
56
57
 
57
58
  } // namespace data
58
59
  } // namespace mlpack