mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415)
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -17,171 +17,48 @@
17
17
  // In case it hasn't already been included.
18
18
  #include "load.hpp"
19
19
 
20
- #include <algorithm>
21
- #include <exception>
22
-
23
- #include "extension.hpp"
24
- #include "string_algorithms.hpp"
25
-
26
20
  namespace mlpack {
27
- namespace data {
28
-
29
- // The following functions are kept for backward compatibility,
30
- // Please remove them when we release mlpack 5.
31
- template<typename eT>
32
- bool Load(const std::string& filename,
33
- arma::Mat<eT>& matrix,
34
- const bool fatal,
35
- const bool transpose,
36
- const FileType inputLoadType)
37
- {
38
- MatrixOptions opts;
39
- opts.Fatal() = fatal;
40
- opts.NoTranspose() = !transpose;
41
- opts.Format() = inputLoadType;
42
-
43
- return Load(filename, matrix, opts);
44
- }
45
-
46
- // For loading data into sparse matrix
47
- template <typename eT>
48
- bool Load(const std::string& filename,
49
- arma::SpMat<eT>& matrix,
50
- const bool fatal,
51
- const bool transpose,
52
- const FileType inputLoadType)
53
- {
54
- MatrixOptions opts;
55
- opts.Fatal() = fatal;
56
- opts.NoTranspose() = !transpose;
57
- opts.Format() = inputLoadType;
58
-
59
- return Load(filename, matrix, opts);
60
- }
61
-
62
- // For loading data into a column vector
63
- template <typename eT>
64
- bool Load(const std::string& filename,
65
- arma::Col<eT>& vec,
66
- const bool fatal)
67
- {
68
- DataOptions opts;
69
- opts.Fatal() = fatal;
70
- return Load(filename, vec, opts);
71
- }
72
-
73
- // For loading data into a raw vector
74
- template <typename eT>
75
- bool Load(const std::string& filename,
76
- arma::Row<eT>& rowvec,
77
- const bool fatal)
78
- {
79
- DataOptions opts;
80
- opts.Fatal() = fatal;
81
- return Load(filename, rowvec, opts);
82
- }
83
-
84
- // Load with mappings. Unfortunately we have to implement this ourselves.
85
- template<typename eT, typename PolicyType>
86
- bool Load(const std::string& filename,
87
- arma::Mat<eT>& matrix,
88
- DatasetMapper<PolicyType>& info,
89
- const bool fatal,
90
- const bool transpose)
91
- {
92
- TextOptions opts;
93
- opts.Fatal() = fatal;
94
- opts.NoTranspose() = !transpose;
95
- opts.Categorical() = true;
96
-
97
- if constexpr (std::is_same_v<PolicyType, data::IncrementPolicy>)
98
- {
99
- opts.DatasetInfo() = info;
100
- }
101
- else if constexpr (std::is_same_v<PolicyType, data::MissingPolicy>)
102
- {
103
- opts.MissingPolicy() = true;
104
- opts.DatasetMissingPolicy() = info;
105
- }
106
-
107
- bool success = Load(filename, matrix, opts);
108
-
109
- if constexpr (std::is_same_v<PolicyType, data::IncrementPolicy>)
110
- {
111
- info = opts.DatasetInfo();
112
- }
113
- else if constexpr (std::is_same_v<PolicyType, data::MissingPolicy>)
114
- {
115
- info = opts.DatasetMissingPolicy();
116
- }
117
-
118
- return success;
119
- }
120
21
 
121
22
  template<typename MatType, typename DataOptionsType>
122
23
  bool Load(const std::string& filename,
123
24
  MatType& matrix,
124
25
  const DataOptionsType& opts,
125
- std::enable_if_t<IsArma<MatType>::value ||
126
- IsSparseMat<MatType>::value>*,
127
- std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
26
+ const typename std::enable_if_t<
27
+ IsDataOptions<DataOptionsType>::value>*)
128
28
  {
129
29
  DataOptionsType tmpOpts(opts);
130
- return Load(filename, matrix, tmpOpts);
30
+ return Load(filename, matrix, tmpOpts, false);
131
31
  }
132
32
 
133
- template<typename MatType>
134
- bool LoadMatrix(const std::string& filename,
135
- MatType& matrix,
136
- std::fstream& stream,
137
- TextOptions& txtOpts)
33
+ template<typename eT, typename DataOptionsType>
34
+ bool Load(const std::vector<std::string>& files,
35
+ arma::Mat<eT>& matrix,
36
+ const DataOptionsType& opts,
37
+ const typename std::enable_if_t<
38
+ IsDataOptions<DataOptionsType>::value>*)
138
39
  {
139
- bool success = false;
140
- if constexpr (IsSparseMat<MatType>::value)
141
- {
142
- success = LoadSparse(filename, matrix, txtOpts, stream);
143
- }
144
- else if (txtOpts.Categorical() ||
145
- (txtOpts.Format() == FileType::ARFFASCII))
146
- {
147
- success = LoadCategorical(filename, matrix, txtOpts);
148
- }
149
- else if constexpr (IsCol<MatType>::value)
150
- {
151
- success = LoadCol(filename, matrix, txtOpts, stream);
152
- }
153
- else if constexpr (IsRow<MatType>::value)
154
- {
155
- success = LoadRow(filename, matrix, txtOpts, stream);
156
- }
157
- else if constexpr (IsDense<MatType>::value)
158
- {
159
- success = LoadDense(filename, matrix, txtOpts, stream);
160
- }
161
- else
162
- {
163
- if (txtOpts.Fatal())
164
- Log::Fatal << "data::Load(): unknown matrix-like type given!"
165
- << std::endl;
166
- else
167
- Log::Warn << "data::Load(): unknown matrix-like type given!"
168
- << std::endl;
169
-
170
- return false;
171
- }
172
- return success;
40
+ DataOptionsType tmpOpts(opts);
41
+ return Load(files, matrix, tmpOpts, false);
173
42
  }
174
43
 
175
- template<typename MatType, typename DataOptionsType>
44
+ template<typename ObjectType, typename DataOptionsType>
176
45
  bool Load(const std::string& filename,
177
- MatType& matrix,
46
+ ObjectType& matrix,
178
47
  DataOptionsType& opts,
179
- std::enable_if_t<IsArma<MatType>::value ||
180
- IsSparseMat<MatType>::value>*,
181
- std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
48
+ const bool copyBack,
49
+ const typename std::enable_if_t<
50
+ IsDataOptions<DataOptionsType>::value>*)
182
51
  {
183
52
  Timer::Start("loading_data");
184
53
 
54
+ static_assert(!IsArma<ObjectType>::value || !IsSparseMat<ObjectType>::value
55
+ || !HasSerialize<ObjectType>::value, "mlpack can load Armadillo"
56
+ " matrices or serialized mlpack models only; please use a known type.");
57
+ const bool isMatrixType = IsArma<ObjectType>::value ||
58
+ IsSparseMat<ObjectType>::value;
59
+ const bool isSerializable = HasSerialize<ObjectType>::value;
60
+ const bool isSparseMatrixType = IsSparseMat<ObjectType>::value;
61
+
185
62
  std::fstream stream;
186
63
  bool success = OpenFile(filename, opts, true, stream);
187
64
  if (!success)
@@ -190,43 +67,65 @@ bool Load(const std::string& filename,
190
67
  return false;
191
68
  }
192
69
 
193
- success = DetectFileType<MatType>(filename, opts, true, &stream);
70
+ success = DetectFileType<ObjectType>(filename, opts, true, &stream);
194
71
  if (!success)
195
72
  {
196
73
  Timer::Stop("loading_data");
197
74
  return false;
198
75
  }
76
+ const bool isImageFormat = (opts.Format() == FileType::PNG ||
77
+ opts.Format() == FileType::JPG || opts.Format() == FileType::PNM ||
78
+ opts.Format() == FileType::BMP || opts.Format() == FileType::GIF ||
79
+ opts.Format() == FileType::PSD || opts.Format() == FileType::TGA ||
80
+ opts.Format() == FileType::PIC || opts.Format() == FileType::ImageType);
199
81
 
200
- if constexpr (IsArma<MatType>::value || IsSparseMat<MatType>::value)
82
+ if constexpr (isMatrixType)
201
83
  {
202
- TextOptions txtOpts(std::move(opts));
203
- success = LoadMatrix(filename, matrix, stream, txtOpts);
204
- opts = std::move(txtOpts);
84
+ if (isImageFormat)
85
+ {
86
+ if constexpr (isSparseMatrixType)
87
+ {
88
+ return HandleError("Cannot load image data into a sparse matrix. "
89
+ "Please use dense matrix instead.", opts);
90
+ }
91
+ else
92
+ {
93
+ ImageOptions imgOpts(std::move(opts));
94
+ std::vector<std::string> files;
95
+ files.push_back(filename);
96
+ success = LoadImage(files, matrix, imgOpts);
97
+ if (copyBack)
98
+ opts = std::move(imgOpts);
99
+ }
100
+ }
101
+ else
102
+ {
103
+ TextOptions txtOpts(std::move(opts));
104
+ success = LoadNumeric(filename, matrix, stream, txtOpts);
105
+ if (copyBack)
106
+ opts = std::move(txtOpts);
107
+ }
108
+ }
109
+ else if constexpr (isSerializable)
110
+ {
111
+ success = LoadModel(matrix, opts, stream);
205
112
  }
206
113
  else
207
114
  {
208
- if (opts.Fatal())
209
- Log::Fatal << "DataOptionsType is unknown! Please use a known type "
210
- << "or provide specific overloads." << std::endl;
211
- else
212
- Log::Warn << "DataOptionsType is unknown! Please use a known type "
213
- << "or provide specific overloads." << std::endl;
214
- return false;
115
+ return HandleError("DataOptionsType is unknown! Please use a known type "
116
+ "or provide specific overloads.", opts);
215
117
  }
216
118
 
217
119
  if (!success)
218
120
  {
219
121
  Timer::Stop("loading_data");
220
- if (opts.Fatal())
221
- Log::Fatal << "Loading from '" << filename << "' failed." << std::endl;
222
- else
223
- Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
224
-
225
- return false;
122
+ std::stringstream oss;
123
+ oss << "Loading from '" << filename << "' failed.";
124
+ return HandleError(oss, opts);
226
125
  }
227
126
  else
228
127
  {
229
- if constexpr (IsArma<MatType>::value)
128
+ if constexpr (IsArma<ObjectType>::value)
230
129
  {
231
130
  Log::Info << "Size is " << matrix.n_rows << " x "
232
131
  << matrix.n_cols << ".\n";
@@ -238,115 +137,42 @@ bool Load(const std::string& filename,
238
137
  return success;
239
138
  }
240
139
 
241
- template<typename MatType>
242
- bool LoadDense(const std::string& filename,
243
- MatType& matrix,
244
- TextOptions& opts,
245
- std::fstream& stream)
246
- {
247
- bool success;
248
- if (opts.Format() != FileType::RawBinary)
249
- Log::Info << "Loading '" << filename << "' as "
250
- << opts.FileTypeToString() << ". " << std::flush;
251
-
252
- // We can't use the stream if the type is HDF5.
253
- if (opts.Format() == FileType::HDF5Binary)
254
- {
255
- success = LoadHDF5(filename, matrix, opts);
256
- }
257
- else if (opts.Format() == FileType::CSVASCII)
258
- {
259
- success = LoadCSVASCII(filename, matrix, opts);
260
-
261
- if (matrix.col(0).is_zero())
262
- Log::Warn << "data::Load(): the first line in '" << filename << "' was "
263
- << "loaded as all zeros; if the first row is headers, specify "
264
- << "`HasHeaders() = true` in the given DataOptions." << std::endl;
265
- }
266
- else
267
- {
268
- if (opts.Format() == FileType::RawBinary)
269
- Log::Warn << "Loading '" << filename << "' as "
270
- << opts.FileTypeToString() << "; "
271
- << "but this may not be the actual filetype!" << std::endl;
272
-
273
- success = matrix.load(stream, ToArmaFileType(opts.Format()));
274
- if (!opts.NoTranspose())
275
- inplace_trans(matrix);
276
- }
277
- return success;
278
- }
279
-
280
- template <typename eT>
281
- bool LoadSparse(const std::string& filename,
282
- arma::SpMat<eT>& matrix,
283
- TextOptions& opts,
284
- std::fstream& stream)
140
+ template<typename eT, typename DataOptionsType>
141
+ bool Load(const std::vector<std::string>& files,
142
+ arma::Mat<eT>& matrix,
143
+ DataOptionsType& opts,
144
+ const bool copyBack,
145
+ const typename std::enable_if_t<
146
+ IsDataOptions<DataOptionsType>::value>*)
285
147
  {
286
- bool success;
287
- // There is still a small amount of differentiation that needs to be done:
288
- // if we got a text type, it could be a coordinate list. We will make an
289
- // educated guess based on the shape of the input.
290
- if (opts.Format() == FileType::RawASCII)
148
+ bool success = false;
149
+ if (files.empty())
291
150
  {
292
- // Get the number of columns in the file. If it is the right shape, we
293
- // will assume it is sparse.
294
- const size_t cols = CountCols(stream);
295
- if (cols == 3)
296
- {
297
- // We have the right number of columns, so assume the type is a
298
- // coordinate list.
299
- opts.Format() = FileType::CoordASCII;
300
- }
151
+ return HandleError("Load(): given set of filenames is empty;"
152
+ " loading failed.", opts);
301
153
  }
302
154
 
303
- // Filter out invalid types.
304
- if ((opts.Format() == FileType::PGMBinary) ||
305
- (opts.Format() == FileType::PPMBinary) ||
306
- (opts.Format() == FileType::ArmaASCII) ||
307
- (opts.Format() == FileType::RawBinary))
308
- {
309
- if (opts.Fatal())
310
- Log::Fatal << "Cannot load '" << filename << "' with type "
311
- << opts.FileTypeToString() << " into a sparse matrix; format is "
312
- << "only supported for dense matrices." << std::endl;
313
- else
314
- Log::Warn << "Cannot load '" << filename << "' with type "
315
- << opts.FileTypeToString() << " into a sparse matrix; format is "
316
- << "only supported for dense matrices; load failed." << std::endl;
155
+ DetectFromExtension<arma::Mat<eT>>(files.back(), opts);
156
+ const bool isImageFormat = (opts.Format() == FileType::PNG ||
157
+ opts.Format() == FileType::JPG || opts.Format() == FileType::PNM ||
158
+ opts.Format() == FileType::BMP || opts.Format() == FileType::GIF ||
159
+ opts.Format() == FileType::PSD || opts.Format() == FileType::TGA ||
160
+ opts.Format() == FileType::PIC || opts.Format() == FileType::ImageType);
317
161
 
318
- return false;
319
- }
320
- else if (opts.Format() == FileType::CSVASCII)
162
+ if (isImageFormat)
321
163
  {
322
- // Armadillo sparse matrices can't load CSVs, so we have to load a separate
323
- // matrix to do that. If the CSV has three columns, we assume it's a
324
- // coordinate list.
325
- arma::Mat<eT> dense;
326
- success = dense.load(stream, ToArmaFileType(opts.Format()));
327
- if (dense.n_cols == 3)
328
- {
329
- arma::umat locations = arma::conv_to<arma::umat>::from(
330
- dense.cols(0, 1).t());
331
- matrix = arma::SpMat<eT>(locations, dense.col(2));
332
- }
333
- else
334
- {
335
- matrix = arma::conv_to<arma::SpMat<eT>>::from(dense);
336
- }
164
+ ImageOptions imgOpts(std::move(opts));
165
+ success = LoadImage(files, matrix, imgOpts);
166
+ if (copyBack)
167
+ opts = std::move(imgOpts);
337
168
  }
338
169
  else
339
170
  {
340
- success = matrix.load(stream, ToArmaFileType(opts.Format()));
341
- }
342
-
343
- if (!opts.NoTranspose())
344
- {
345
- // It seems that there is no direct way to use inplace_trans() on
346
- // sparse matrices.
347
- matrix = matrix.t();
171
+ TextOptions txtOpts(std::move(opts));
172
+ success = LoadNumericMultifile(files, matrix, txtOpts);
173
+ if (copyBack)
174
+ opts = std::move(txtOpts);
348
175
  }
349
-
350
176
  return success;
351
177
  }
352
178
 
@@ -393,14 +219,10 @@ bool LoadCategorical(const std::string& filename,
393
219
  {
394
220
  // The type is unknown.
395
221
  Timer::Stop("loading_data");
396
- if (opts.Fatal())
397
- Log::Fatal << "Unable to detect type of '" << filename << "'; "
398
- << "Incorrect extension?" << std::endl;
399
- else
400
- Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
401
- << " Incorrect extension?" << std::endl;
402
-
403
- return false;
222
+ std::stringstream oss;
223
+ oss << "Unable to detect type of '" << filename << "'; "
224
+ << "Incorrect extension?";
225
+ return HandleError(oss, opts);
404
226
  }
405
227
 
406
228
  Log::Info << "Size is " << matrix.n_rows << " x " << matrix.n_cols << ".\n";
@@ -410,7 +232,6 @@ bool LoadCategorical(const std::string& filename,
410
232
  return true;
411
233
  }
412
234
 
413
- } // namespace data
414
235
  } // namespace mlpack
415
236
 
416
237
  #endif
@@ -0,0 +1,62 @@
1
+ /**
2
+ * @file core/data/load_model.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Intenal implementation of model-specific Load() function.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_DATA_LOAD_MODEL_HPP
14
+ #define MLPACK_CORE_DATA_LOAD_MODEL_HPP
15
+
16
+ #include <cereal/archives/xml.hpp>
17
+ #include <cereal/archives/binary.hpp>
18
+ #include <cereal/archives/json.hpp>
19
+
20
+ #include "text_options.hpp"
21
+
22
+ namespace mlpack {
23
+
24
+ template<typename Object>
25
+ bool LoadModel(Object& objectToSerialize,
26
+ DataOptionsBase<PlainDataOptions>& opts,
27
+ std::fstream& stream)
28
+ {
29
+ try
30
+ {
31
+ if (opts.Format() == FileType::XML)
32
+ {
33
+ cereal::XMLInputArchive ar(stream);
34
+ ar(cereal::make_nvp("model", objectToSerialize));
35
+ }
36
+ else if (opts.Format() == FileType::JSON)
37
+ {
38
+ cereal::JSONInputArchive ar(stream);
39
+ ar(cereal::make_nvp("model", objectToSerialize));
40
+ }
41
+ else if (opts.Format() == FileType::BIN)
42
+ {
43
+ cereal::BinaryInputArchive ar(stream);
44
+ ar(cereal::make_nvp("model", objectToSerialize));
45
+ }
46
+
47
+ return true;
48
+ }
49
+ catch (cereal::Exception& e)
50
+ {
51
+ if (opts.Fatal())
52
+ Log::Fatal << e.what() << std::endl;
53
+ else
54
+ Log::Warn << e.what() << std::endl;
55
+
56
+ return false;
57
+ }
58
+ }
59
+
60
+ } // namespace mlpack
61
+
62
+ #endif