mlpack 4.6.2__cp38-cp38-win_amd64.whl → 4.7.0__cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (414) hide show
  1. mlpack/__init__.py +3 -3
  2. mlpack/adaboost_classify.cp38-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp38-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp38-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp38-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp38-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp38-win_amd64.pyd +0 -0
  8. mlpack/cf.cp38-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp38-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp38-win_amd64.pyd +0 -0
  11. mlpack/det.cp38-win_amd64.pyd +0 -0
  12. mlpack/emst.cp38-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp38-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp38-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp38-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp38-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp38-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp38-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp38-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp38-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp38-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp38-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp38-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp38-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp38-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp38-win_amd64.pyd +0 -0
  365. mlpack/knn.cp38-win_amd64.pyd +0 -0
  366. mlpack/krann.cp38-win_amd64.pyd +0 -0
  367. mlpack/lars.cp38-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp38-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp38-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp38-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp38-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp38-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp38-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp38-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp38-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp38-win_amd64.pyd +0 -0
  377. mlpack/nca.cp38-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp38-win_amd64.pyd +0 -0
  379. mlpack/pca.cp38-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp38-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp38-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp38-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp38-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp38-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp38-win_amd64.pyd +0 -0
  386. mlpack/radical.cp38-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp38-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp38-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp38-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +5 -5
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +395 -376
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{.load-order-mlpack-4.6.2 → .load-order-mlpack-4.7.0} +0 -0
@@ -1,8 +1,9 @@
1
1
  /**
2
2
  * @file core/data/save_image.hpp
3
3
  * @author Ryan Curtin
4
+ * @author Omar Shrit
4
5
  *
5
- * Implementation of save functionality.
6
+ * Implementation of save image functionality.
6
7
  *
7
8
  * mlpack is free software; you may redistribute it and/or modify it under the
8
9
  * terms of the 3-clause BSD license. You should have received a copy of the
@@ -13,54 +14,93 @@
13
14
  #define MLPACK_CORE_DATA_SAVE_IMAGE_HPP
14
15
 
15
16
  #include <mlpack/core/stb/stb.hpp>
16
-
17
- #include "image_info.hpp"
17
+ #include <mlpack/core/math/make_alias.hpp>
18
18
 
19
19
  namespace mlpack {
20
- namespace data {
21
20
 
22
- /**
23
- * Save the image file from the given matrix.
24
- *
25
- * @param filename Name of the image file.
26
- * @param matrix Matrix to save the image from.
27
- * @param info An object of ImageInfo class.
28
- * @param fatal If an error should be reported as fatal (default false).
29
- * @return Boolean value indicating success or failure of load.
30
- */
31
21
  template<typename eT>
32
- bool Save(const std::string& filename,
33
- arma::Mat<eT>& matrix,
34
- ImageInfo& info,
35
- const bool fatal = false);
22
+ bool SaveImage(const std::vector<std::string>& files,
23
+ const arma::Mat<eT>& matrix,
24
+ ImageOptions& opts)
25
+ {
26
+ if (files.empty())
27
+ {
28
+ std::stringstream oss;
29
+ oss << "Save(): vector of image files is empty; nothing to save.";
30
+ return HandleError(oss, opts);
31
+ }
36
32
 
37
- /**
38
- * Save the image file from the given matrix.
39
- *
40
- * @param files A vector consisting of filenames.
41
- * @param matrix Matrix to save the image from.
42
- * @param info An object of ImageInfo class.
43
- * @param fatal If an error should be reported as fatal (default false).
44
- * @return Boolean value indicating success or failure of load.
45
- */
46
- template<typename eT>
47
- bool Save(const std::vector<std::string>& files,
48
- arma::Mat<eT>& matrix,
49
- ImageInfo& info,
50
- const bool fatal = false);
33
+ // Check if we do have any type that is not supported.
34
+ if (opts.Format() == FileType::ImageType ||
35
+ opts.Format() == FileType::AutoDetect)
36
+ {
37
+ for (size_t i = 0; i < files.size() ; ++i)
38
+ {
39
+ if (!opts.saveType.count(Extension(files.at(i))))
40
+ {
41
+ std::stringstream oss;
42
+ oss << "Save(): file type " << opts.FileTypeToString()
43
+ << " isn't supported. Currently image saving supports: ";
44
+ for (const auto& x : opts.saveType)
45
+ oss << " " << x;
46
+ oss << "." << std::endl;
47
+ return HandleError(oss, opts);
48
+ }
49
+ }
50
+ }
51
51
 
52
- /**
53
- * Helper function to save files. Implementation in save_image.hpp.
54
- */
55
- inline bool SaveImage(const std::string& filename,
56
- arma::Mat<unsigned char>& image,
57
- ImageInfo& info,
58
- const bool fatal = false);
52
+ size_t dimension = opts.Width() * opts.Height() * opts.Channels() *
53
+ files.size();
54
+ // We only need to check the rows since it is a matrix.
55
+ if (dimension != matrix.n_rows * matrix.n_cols)
56
+ {
57
+ std::stringstream oss;
58
+ oss << "Save(): The given image dimensions, Width: " << opts.Width()
59
+ << ", Height: " << opts.Height() << ", Channels: "<< opts.Channels()
60
+ << " do not match the dimensions of the matrix to be saved!";
61
+ return HandleError(oss, opts);
62
+ }
63
+ // Unfortunately we cannot move because matrix is const.
64
+ arma::Mat<unsigned char> tempMatrix =
65
+ arma::conv_to<arma::Mat<unsigned char>>::from(matrix);
66
+ bool success = false;
67
+ for (size_t i = 0; i < files.size() ; ++i)
68
+ {
69
+ // Update opts.Format() at each iteration.
70
+ DetectFromExtension<arma::Mat<eT>, ImageOptions>(files.at(i), opts);
71
+ if (opts.Format() == FileType::PNG)
72
+ {
73
+ success = stbi_write_png(files.at(i).c_str(), opts.Width(), opts.Height(),
74
+ opts.Channels(), tempMatrix.colptr(i),
75
+ opts.Width() * opts.Channels());
76
+ }
77
+ else if (opts.Format() == FileType::BMP)
78
+ {
79
+ success = stbi_write_bmp(files.at(i).c_str(), opts.Width(), opts.Height(),
80
+ opts.Channels(), tempMatrix.colptr(i));
81
+ }
82
+ else if (opts.Format() == FileType::TGA)
83
+ {
84
+ success = stbi_write_tga(files.at(i).c_str(), opts.Width(), opts.Height(),
85
+ opts.Channels(), tempMatrix.colptr(i));
86
+ }
87
+ else if (opts.Format() == FileType::JPG)
88
+ {
89
+ success = stbi_write_jpg(files.at(i).c_str(), opts.Width(), opts.Height(),
90
+ opts.Channels(), tempMatrix.colptr(i), opts.Quality());
91
+ }
92
+
93
+ if (!success)
94
+ {
95
+ std::stringstream oss;
96
+ oss << "Save(): error saving image to '" << files.at(i) << "'.";
97
+ return HandleError(oss, opts);
98
+ }
99
+ }
59
100
 
60
- } //namespace data
61
- } //namespace mlpack
101
+ return success;
102
+ }
62
103
 
63
- // Include implementation of Save() for images.
64
- #include "save_image_impl.hpp"
104
+ } // namespace mlpack
65
105
 
66
106
  #endif
@@ -15,164 +15,109 @@
15
15
 
16
16
  // In case it hasn't already been included.
17
17
  #include "save.hpp"
18
- #include "extension.hpp"
19
18
 
20
19
  namespace mlpack {
21
- namespace data {
22
-
23
- template<typename eT>
24
- bool Save(const std::string& filename,
25
- const arma::Col<eT>& vec,
26
- const bool fatal,
27
- FileType inputSaveType)
28
- {
29
- // Don't transpose: one observation per line (for CSVs at least).
30
- return Save(filename, vec, fatal, false, inputSaveType);
31
- }
32
-
33
- template<typename eT>
34
- bool Save(const std::string& filename,
35
- const arma::Row<eT>& rowvec,
36
- const bool fatal,
37
- FileType inputSaveType)
38
- {
39
- return Save(filename, rowvec, fatal, true, inputSaveType);
40
- }
41
-
42
- // Save a Sparse Matrix
43
- template<typename eT>
44
- bool Save(const std::string& filename,
45
- const arma::SpMat<eT>& matrix,
46
- const bool fatal,
47
- bool transpose)
48
- {
49
- MatrixOptions opts;
50
- opts.Fatal() = fatal;
51
- opts.NoTranspose() = !transpose;
52
-
53
- return Save(filename, matrix, opts);
54
- }
55
-
56
- template<typename eT>
57
- bool Save(const std::string& filename,
58
- const arma::Mat<eT>& matrix,
59
- const bool fatal,
60
- bool transpose,
61
- FileType inputSaveType)
62
- {
63
- MatrixOptions opts;
64
- opts.Fatal() = fatal;
65
- opts.NoTranspose() = !transpose;
66
- opts.Format() = inputSaveType;
67
-
68
- return Save(filename, matrix, opts);
69
- }
70
20
 
71
21
  template<typename MatType, typename DataOptionsType>
72
22
  bool Save(const std::string& filename,
73
23
  const MatType& matrix,
74
24
  const DataOptionsType& opts,
75
- std::enable_if_t<IsArma<MatType>::value ||
76
- IsSparseMat<MatType>::value>*)
25
+ const typename std::enable_if_t<
26
+ IsDataOptions<DataOptionsType>::value>*)
77
27
  {
78
28
  //! just use default copy ctor with = operator and make a copy.
79
29
  DataOptionsType copyOpts(opts);
80
30
  return Save(filename, matrix, copyOpts);
81
31
  }
82
32
 
83
- /*
84
- * Add this SFINAE in here because the compiler is so stupid that it is not
85
- * able to distinguish between these two:
86
- *
87
- * data::Save(filename, "model", *output);
88
- *
89
- * and
90
- *
91
- * data::Save(filename, matrix, opts);
92
- *
93
- * The second SFINAE is added because the compiler is bot able to see the
94
- * difference between:
95
- *
96
- * data::Save(filename, Row/Col, fatal);
97
- *
98
- * and
99
- *
100
- * data::Save(filename, Row/Col, Opts);
101
- *
102
- * This SFINAE is temporary and must be removed after the integration of stage 3 or
103
- * when the compiler becomes more intelligent.
104
- */
105
- template<typename MatType, typename DataOptionsType>
33
+ template<typename ObjectType, typename DataOptionsType>
106
34
  bool Save(const std::string& filename,
107
- const MatType& matrix,
35
+ const ObjectType& matrix,
108
36
  DataOptionsType& opts,
109
- std::enable_if_t<IsArma<MatType>::value ||
110
- IsSparseMat<MatType>::value>*,
111
- std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
37
+ const typename std::enable_if_t<
38
+ IsDataOptions<DataOptionsType>::value>*)
112
39
  {
113
40
  Timer::Start("saving_data");
114
-
115
- bool success = DetectFileType<MatType>(filename, opts, false);
41
+ static_assert(!IsArma<ObjectType>::value || !IsSparseMat<ObjectType>::value
42
+ || !HasSerialize<ObjectType>::value, "mlpack can save Armadillo"
43
+ " matrices or a serialized mlpack model only; please use a known type.");
44
+ const bool isMatrixType = IsArma<ObjectType>::value ||
45
+ IsSparseMat<ObjectType>::value;
46
+ const bool isSerializable = HasSerialize<ObjectType>::value;
47
+ const bool isSparseMatrixType = IsSparseMat<ObjectType>::value;
48
+
49
+ bool success = DetectFileType<ObjectType>(filename, opts, false);
116
50
  if (!success)
117
51
  {
118
52
  Timer::Stop("saving_data");
119
53
  return false;
120
54
  }
121
55
 
56
+ const bool isImageFormat = (opts.Format() == FileType::PNG ||
57
+ opts.Format() == FileType::JPG || opts.Format() == FileType::PNM ||
58
+ opts.Format() == FileType::BMP || opts.Format() == FileType::GIF ||
59
+ opts.Format() == FileType::PSD || opts.Format() == FileType::TGA ||
60
+ opts.Format() == FileType::PIC || opts.Format() == FileType::ImageType);
61
+
122
62
  std::fstream stream;
123
- success = OpenFile(filename, opts, false, stream);
124
- if (!success)
63
+ if (!isImageFormat)
125
64
  {
126
- Timer::Stop("saving_data");
127
- return false;
65
+ success = OpenFile(filename, opts, false, stream);
66
+ if (!success)
67
+ {
68
+ Timer::Stop("saving_data");
69
+ return false;
70
+ }
128
71
  }
129
72
 
130
73
  // Try to save the file.
131
74
  Log::Info << "Saving " << opts.FileTypeToString() << " to '" << filename
132
75
  << "'." << std::endl;
133
- if constexpr (IsArma<MatType>::value || IsSparseMat<MatType>::value)
76
+ if constexpr (isMatrixType)
134
77
  {
135
- TextOptions txtOpts(std::move(opts));
136
- if constexpr (IsSparseMat<MatType>::value)
137
- {
138
- success = SaveSparse(matrix, txtOpts, filename, stream);
139
- }
140
- else if constexpr (IsCol<MatType>::value)
141
- {
142
- opts.NoTranspose() = true;
143
- success = SaveDense(matrix, txtOpts, filename, stream);
144
- }
145
- else if constexpr (IsRow<MatType>::value)
78
+ if (isImageFormat)
146
79
  {
147
- opts.NoTranspose() = false;
148
- success = SaveDense(matrix, txtOpts, filename, stream);
80
+ if constexpr (isSparseMatrixType)
81
+ {
82
+ arma::Mat<typename ObjectType::elem_type> tmp =
83
+ arma::conv_to<arma::Mat<
84
+ typename ObjectType::elem_type>>::from(matrix);
85
+ ImageOptions imgOpts(std::move(opts));
86
+ std::vector<std::string> files;
87
+ files.push_back(filename);
88
+ success = SaveImage(files, tmp, imgOpts);
89
+ opts = std::move(imgOpts);
90
+ }
91
+ else
92
+ {
93
+ ImageOptions imgOpts(std::move(opts));
94
+ std::vector<std::string> files;
95
+ files.push_back(filename);
96
+ success = SaveImage(files, matrix, imgOpts);
97
+ opts = std::move(imgOpts);
98
+ }
149
99
  }
150
- else if constexpr (IsDense<MatType>::value)
100
+ else
151
101
  {
152
- success = SaveDense(matrix, txtOpts, filename, stream);
102
+ success = SaveNumeric(filename, matrix, stream, opts);
153
103
  }
154
- opts = std::move(txtOpts);
104
+ }
105
+ else if constexpr (isSerializable)
106
+ {
107
+ success = SaveModel(matrix, opts, stream);
155
108
  }
156
109
  else
157
110
  {
158
- if (opts.Fatal())
159
- Log::Fatal << "DataOptionsType is unknown! Please use a known type or "
160
- << "or provide specific overloads." << std::endl;
161
- else
162
- Log::Warn << "DataOptionsType is unknown! Please use a known type or "
163
- << "or provide specific overloads." << std::endl;
164
-
165
- return false;
111
+ return HandleError("DataOptionsType is unknown! Please use a known type "
112
+ "or provide specific overloads.", opts);
166
113
  }
167
114
 
168
115
  if (!success)
169
116
  {
170
117
  Timer::Stop("saving_data");
171
- if (opts.Fatal())
172
- Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
173
- else
174
- Log::Warn << "Save to '" << filename << "' failed." << std::endl;
175
- return false;
118
+ std::stringstream oss;
119
+ oss << "Save to '" << filename << "' failed.";
120
+ return HandleError(oss, opts);
176
121
  }
177
122
 
178
123
  Timer::Stop("saving_data");
@@ -180,136 +125,6 @@ bool Save(const std::string& filename,
180
125
  return success;
181
126
  }
182
127
 
183
- template<typename eT>
184
- bool SaveDense(const arma::Mat<eT>& matrix,
185
- TextOptions& opts,
186
- const std::string& filename,
187
- std::fstream& stream)
188
- {
189
- bool success = false;
190
- arma::Mat<eT> tmp;
191
- // Transpose the matrix.
192
- if (!opts.NoTranspose())
193
- {
194
- tmp = trans(matrix);
195
- success = SaveMatrix(tmp, opts, filename, stream);
196
- }
197
- else
198
- success = SaveMatrix(matrix, opts, filename, stream);
199
-
200
- return success;
201
- }
202
-
203
- // Save a Sparse Matrix
204
- template<typename eT>
205
- bool SaveSparse(const arma::SpMat<eT>& matrix,
206
- TextOptions& opts,
207
- const std::string& filename,
208
- std::fstream& stream)
209
- {
210
- bool success = false;
211
- arma::SpMat<eT> tmp;
212
-
213
- // Transpose the matrix.
214
- if (!opts.NoTranspose())
215
- {
216
- arma::SpMat<eT> tmp = trans(matrix);
217
- success = SaveMatrix(tmp, opts, filename, stream);
218
- }
219
- else
220
- success = SaveMatrix(matrix, opts, filename, stream);
221
-
222
- return success;
223
- }
224
-
225
- //! Save a model to file.
226
- template<typename T>
227
- bool Save(const std::string& filename,
228
- const std::string& name,
229
- T& t,
230
- const bool fatal,
231
- format f,
232
- std::enable_if_t<HasSerialize<T>::value>*)
233
- {
234
- if (f == format::autodetect)
235
- {
236
- std::string extension = Extension(filename);
237
-
238
- if (extension == "xml")
239
- f = format::xml;
240
- else if (extension == "bin")
241
- f = format::binary;
242
- else if (extension == "json")
243
- f = format::json;
244
- else
245
- {
246
- if (fatal)
247
- Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
248
- << " extension? (allowed: xml/bin/json)" << std::endl;
249
- else
250
- Log::Warn << "Unable to detect type of '" << filename << "'; save "
251
- << "failed. Incorrect extension? (allowed: xml/bin/json)"
252
- << std::endl;
253
-
254
- return false;
255
- }
256
- }
257
-
258
- // Open the file to save to.
259
- std::ofstream ofs;
260
- #ifdef _WIN32
261
- if (f == format::binary) // Open non-text types in binary mode on Windows.
262
- ofs.open(filename, std::ofstream::out | std::ofstream::binary);
263
- else
264
- ofs.open(filename, std::ofstream::out);
265
- #else
266
- ofs.open(filename, std::ofstream::out);
267
- #endif
268
-
269
- if (!ofs.is_open())
270
- {
271
- if (fatal)
272
- Log::Fatal << "Unable to open file '" << filename << "' to save object '"
273
- << name << "'." << std::endl;
274
- else
275
- Log::Warn << "Unable to open file '" << filename << "' to save object '"
276
- << name << "'." << std::endl;
277
-
278
- return false;
279
- }
280
-
281
- try
282
- {
283
- if (f == format::xml)
284
- {
285
- cereal::XMLOutputArchive ar(ofs);
286
- ar(cereal::make_nvp(name.c_str(), t));
287
- }
288
- else if (f == format::json)
289
- {
290
- cereal::JSONOutputArchive ar(ofs);
291
- ar(cereal::make_nvp(name.c_str(), t));
292
- }
293
- else if (f == format::binary)
294
- {
295
- cereal::BinaryOutputArchive ar(ofs);
296
- ar(cereal::make_nvp(name.c_str(), t));
297
- }
298
-
299
- return true;
300
- }
301
- catch (cereal::Exception& e)
302
- {
303
- if (fatal)
304
- Log::Fatal << e.what() << std::endl;
305
- else
306
- Log::Warn << e.what() << std::endl;
307
-
308
- return false;
309
- }
310
- }
311
-
312
- } // namespace data
313
128
  } // namespace mlpack
314
129
 
315
130
  #endif
@@ -0,0 +1,45 @@
1
+ /**
2
+ * @file core/data/save_matrix.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Internal implementation of matrix save function.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_DATA_SAVE_MATRIX_HPP
14
+ #define MLPACK_CORE_DATA_SAVE_MATRIX_HPP
15
+
16
+ namespace mlpack {
17
+
18
+ template<typename MatType, typename DataOptionsType>
19
+ bool SaveMatrix(const MatType& matrix,
20
+ const DataOptionsType& opts,
21
+ #ifdef ARMA_USE_HDF5
22
+ const std::string& filename,
23
+ #else
24
+ const std::string& /* filename */,
25
+ #endif
26
+ std::fstream& stream)
27
+ {
28
+ bool success = false;
29
+ if (opts.Format() == FileType::HDF5Binary)
30
+ {
31
+ #ifdef ARMA_USE_HDF5
32
+ // We can't save with streams for HDF5.
33
+ success = matrix.save(filename, opts.ArmaFormat());
34
+ #endif
35
+ }
36
+ else
37
+ {
38
+ success = matrix.save(stream, opts.ArmaFormat());
39
+ }
40
+ return success;
41
+ }
42
+
43
+ } // namespace mlpack
44
+
45
+ #endif
@@ -0,0 +1,61 @@
1
+ /**
2
+ * @file core/data/save_model.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Internal implementation of model save function.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_DATA_SAVE_MODEL_HPP
14
+ #define MLPACK_CORE_DATA_SAVE_MODEL_HPP
15
+
16
+ #include <cereal/archives/xml.hpp>
17
+ #include <cereal/archives/binary.hpp>
18
+ #include <cereal/archives/json.hpp>
19
+
20
+ #include "text_options.hpp"
21
+
22
+ namespace mlpack {
23
+
24
+ template<typename Object>
25
+ bool SaveModel(Object& objectToSerialize,
26
+ const DataOptionsBase<PlainDataOptions>& opts,
27
+ std::fstream& stream)
28
+ {
29
+ try
30
+ {
31
+ if (opts.Format() == FileType::XML)
32
+ {
33
+ cereal::XMLOutputArchive ar(stream);
34
+ ar(cereal::make_nvp("model", objectToSerialize));
35
+ }
36
+ else if (opts.Format() == FileType::JSON)
37
+ {
38
+ cereal::JSONOutputArchive ar(stream);
39
+ ar(cereal::make_nvp("model", objectToSerialize));
40
+ }
41
+ else if (opts.Format() == FileType::BIN)
42
+ {
43
+ cereal::BinaryOutputArchive ar(stream);
44
+ ar(cereal::make_nvp("model", objectToSerialize));
45
+ }
46
+ return true;
47
+ }
48
+ catch (cereal::Exception& e)
49
+ {
50
+ if (opts.Fatal())
51
+ Log::Fatal << e.what() << std::endl;
52
+ else
53
+ Log::Warn << e.what() << std::endl;
54
+
55
+ return false;
56
+ }
57
+ }
58
+
59
+ } // namespace mlpack
60
+
61
+ #endif
@@ -0,0 +1,60 @@
1
+ /**
2
+ * @file core/data/save_numeric.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Internal implementation of numeric save function.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_DATA_SAVE_NUMERIC_HPP
14
+ #define MLPACK_CORE_DATA_SAVE_NUMERIC_HPP
15
+
16
+ #include "text_options.hpp"
17
+ #include "save_sparse.hpp"
18
+ #include "save_dense.hpp"
19
+
20
+ namespace mlpack {
21
+
22
+ template<typename ObjectType, typename DataOptionsType>
23
+ bool SaveNumeric(const std::string& filename,
24
+ const ObjectType& matrix,
25
+ std::fstream& stream,
26
+ DataOptionsBase<DataOptionsType>& opts)
27
+ {
28
+ bool success = false;
29
+
30
+ TextOptions txtOpts(std::move(opts));
31
+ if constexpr (IsSparseMat<ObjectType>::value)
32
+ {
33
+ success = SaveSparse(matrix, txtOpts, filename, stream);
34
+ }
35
+ else if constexpr (IsCol<ObjectType>::value)
36
+ {
37
+ const bool oldNoTranspose = txtOpts.NoTranspose();
38
+ txtOpts.NoTranspose() = true; // Force no transpose for a column.
39
+ success = SaveDense(matrix, txtOpts, filename, stream);
40
+ txtOpts.NoTranspose() = oldNoTranspose;
41
+ }
42
+ else if constexpr (IsRow<ObjectType>::value)
43
+ {
44
+ const bool oldNoTranspose = txtOpts.NoTranspose();
45
+ txtOpts.NoTranspose() = false; // Force transpose for a row.
46
+ success = SaveDense(matrix, txtOpts, filename, stream);
47
+ txtOpts.NoTranspose() = oldNoTranspose;
48
+ }
49
+ else if constexpr (IsDense<ObjectType>::value)
50
+ {
51
+ success = SaveDense(matrix, txtOpts, filename, stream);
52
+ }
53
+ static_cast<DataOptionsType&>(opts) = std::move(txtOpts);
54
+
55
+ return success;
56
+ }
57
+
58
+ } // namespace mlpack
59
+
60
+ #endif