mlpack 4.6.1__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (426) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +25 -16
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +53 -43
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +194 -57
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +130 -315
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/ccov.hpp +1 -0
  120. mlpack/include/mlpack/core/math/ccov_impl.hpp +4 -5
  121. mlpack/include/mlpack/core/math/make_alias.hpp +100 -3
  122. mlpack/include/mlpack/core/math/random.hpp +19 -5
  123. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  124. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  125. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  126. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  127. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  128. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  129. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  130. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  131. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -21
  132. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  133. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  134. mlpack/include/mlpack/core/util/param.hpp +4 -4
  135. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  136. mlpack/include/mlpack/core/util/sfinae_utility.hpp +24 -2
  137. mlpack/include/mlpack/core/util/using.hpp +29 -2
  138. mlpack/include/mlpack/core/util/version.hpp +5 -3
  139. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  140. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  141. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  142. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  143. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  144. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  145. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  146. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  147. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  149. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  150. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  151. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  152. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  153. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  154. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  155. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  156. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  157. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  158. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  159. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  160. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  161. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  162. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  163. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  165. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  166. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  167. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  168. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  169. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  171. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  172. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  173. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  174. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  175. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  176. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +2 -3
  177. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  178. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  179. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  180. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  181. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  184. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/network_init.hpp +5 -5
  186. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  187. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  188. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  189. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  190. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  191. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  192. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  193. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  194. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  195. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  196. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  197. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  198. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  199. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  200. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  201. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +19 -20
  202. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +68 -68
  203. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  204. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  205. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  206. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  207. mlpack/include/mlpack/methods/ann/layer/concat.hpp +19 -18
  208. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +19 -20
  209. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  210. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  211. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  212. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  213. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  214. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  215. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  216. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  217. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  218. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  219. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  220. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  221. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  222. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  223. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  224. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  225. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  226. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  227. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  228. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  229. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  230. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  231. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  232. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  233. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  234. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  235. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  236. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  237. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  238. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  239. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  240. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +19 -18
  241. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +29 -32
  242. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  243. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  244. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  245. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  246. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  247. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  248. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  249. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  250. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  251. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +24 -23
  252. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +28 -27
  253. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +27 -26
  254. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +30 -31
  255. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  256. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  257. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +32 -27
  258. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +185 -89
  259. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +29 -25
  260. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +38 -39
  261. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  262. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/padding.hpp +22 -17
  264. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +45 -32
  265. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  266. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  267. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  268. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  269. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +16 -2
  270. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  271. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  272. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  273. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  274. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  275. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  276. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  277. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  278. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  279. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  280. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  281. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +8 -7
  282. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  284. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  285. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  286. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  287. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  288. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  289. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  290. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  291. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  292. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  293. mlpack/include/mlpack/methods/ann/rnn.hpp +145 -50
  294. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +245 -53
  295. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  296. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_impl.hpp +3 -8
  297. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  298. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  299. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  300. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  301. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  302. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  303. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  304. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  305. mlpack/include/mlpack/methods/decision_tree/fitness_functions/gini_gain.hpp +5 -8
  306. mlpack/include/mlpack/methods/decision_tree/fitness_functions/information_gain.hpp +5 -8
  307. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  308. mlpack/include/mlpack/methods/gmm/diagonal_gmm_impl.hpp +2 -1
  309. mlpack/include/mlpack/methods/gmm/eigenvalue_ratio_constraint.hpp +3 -3
  310. mlpack/include/mlpack/methods/gmm/gmm_impl.hpp +2 -1
  311. mlpack/include/mlpack/methods/hmm/hmm_impl.hpp +10 -5
  312. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  313. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  314. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  315. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  316. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  317. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  318. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  319. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  320. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  321. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  322. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  323. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  324. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  325. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  326. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  327. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  328. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  329. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  330. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  331. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  332. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  333. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  334. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  335. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  336. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  337. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  338. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  339. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  340. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  341. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  342. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  343. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  344. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  345. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  346. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  347. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  348. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  349. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  350. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +61 -41
  351. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +77 -67
  352. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  353. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  354. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  355. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  356. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  357. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  358. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  359. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  360. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  361. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  362. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  363. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  364. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  365. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  366. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  367. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  368. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  369. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  370. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  371. mlpack/include/mlpack/prereqs.hpp +1 -0
  372. mlpack/kde.cp313-win_amd64.pyd +0 -0
  373. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  374. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  375. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  376. mlpack/knn.cp313-win_amd64.pyd +0 -0
  377. mlpack/krann.cp313-win_amd64.pyd +0 -0
  378. mlpack/lars.cp313-win_amd64.pyd +0 -0
  379. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  380. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  381. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  382. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  383. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  384. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  385. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  386. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  387. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  388. mlpack/nca.cp313-win_amd64.pyd +0 -0
  389. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  390. mlpack/pca.cp313-win_amd64.pyd +0 -0
  391. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  392. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  393. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  394. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  395. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  396. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  397. mlpack/radical.cp313-win_amd64.pyd +0 -0
  398. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  399. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  400. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  401. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  402. {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  403. {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/RECORD +407 -388
  404. {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  405. mlpack/include/mlpack/core/data/format.hpp +0 -31
  406. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  407. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  408. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  409. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  410. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  411. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  412. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  413. mlpack/include/mlpack/core/data/types.hpp +0 -61
  414. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  415. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  416. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  417. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  418. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  419. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  420. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  421. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  422. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  423. mlpack-4.6.1.dist-info/DELVEWHEEL +0 -2
  424. {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  425. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  426. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
mlpack/__init__.py CHANGED
@@ -11,14 +11,14 @@ http://www.opensource.org/licenses/BSD-3-Clause for more information.
11
11
 
12
12
 
13
13
  # start delvewheel patch
14
- def _delvewheel_patch_1_10_1():
14
+ def _delvewheel_patch_1_12_0():
15
15
  import os
16
16
  if os.path.isdir(libs_dir := os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'mlpack.libs'))):
17
17
  os.add_dll_directory(libs_dir)
18
18
 
19
19
 
20
- _delvewheel_patch_1_10_1()
21
- del _delvewheel_patch_1_10_1
20
+ _delvewheel_patch_1_12_0()
21
+ del _delvewheel_patch_1_12_0
22
22
  # end delvewheel patch
23
23
 
24
24
  import warnings
@@ -74,4 +74,4 @@ from .adaboost import *
74
74
  from .linear_regression_train import linear_regression_train
75
75
  from .linear_regression_predict import linear_regression_predict
76
76
  from .linear_regression import *
77
- __version__='4.6.1'
77
+ __version__='4.7.0'
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
Binary file
@@ -86,6 +86,7 @@
86
86
  #include <armadillo>
87
87
  #include <mlpack/core/util/arma_traits.hpp>
88
88
  #include <mlpack/core/util/omp_reductions.hpp>
89
+ #include <mlpack/core/arma_extend/find_nan.hpp>
89
90
 
90
91
  // On Visual Studio, disable C4519 (default arguments for function templates)
91
92
  // since it's by default an error, which doesn't even make any sense because
@@ -0,0 +1,63 @@
1
+ /**
2
+ * @file find_nan.hpp
3
+ * @author Ryan Curtin
4
+ *
5
+ * When find_nan() is not available (Armadillo < 11.4), provide an internal
6
+ * mlpack implementation that operates the same way. It is slower.
7
+ */
8
+ #ifndef MLPACK_CORE_ARMA_EXTEND_FIND_NAN_HPP
9
+ #define MLPACK_CORE_ARMA_EXTEND_FIND_NAN_HPP
10
+
11
+ namespace mlpack {
12
+
13
+ #if ARMA_VERSION_MAJOR < 11 || \
14
+ (ARMA_VERSION_MAJOR == 11 && ARMA_VERSION_MINOR < 4)
15
+
16
+ template<typename T>
17
+ arma::uvec find_nan(const T& m,
18
+ const std::enable_if_t<arma::is_arma_type<T>::value>* = 0)
19
+ {
20
+ typedef typename T::elem_type ElemType;
21
+
22
+ if (!std::numeric_limits<ElemType>::has_quiet_NaN)
23
+ return arma::uvec(); // There can't be any NaNs.
24
+
25
+ // find_nonfinite() exists on older Armadillo, and we can also search for +Inf
26
+ // and -Inf.
27
+ arma::uvec nonfiniteIndices = arma::find_nonfinite(m);
28
+ if (nonfiniteIndices.n_elem == 0)
29
+ return arma::uvec();
30
+
31
+ arma::uvec infIndices = arma::find(
32
+ m == std::numeric_limits<ElemType>::infinity());
33
+ arma::uvec neginfIndices = arma::find(
34
+ m == -std::numeric_limits<ElemType>::infinity());
35
+
36
+ arma::uvec result(nonfiniteIndices.n_elem -
37
+ (infIndices.n_elem + neginfIndices.n_elem));
38
+ if (result.n_elem == 0)
39
+ return result;
40
+
41
+ size_t infIndex = 0;
42
+ size_t neginfIndex = 0;
43
+ size_t outputIndex = 0;
44
+ for (size_t i = 0; i < nonfiniteIndices.n_elem; ++i)
45
+ {
46
+ if (infIndex < infIndices.n_elem &&
47
+ nonfiniteIndices[i] == infIndices[infIndex])
48
+ ++infIndex;
49
+ else if (neginfIndex < neginfIndices.n_elem &&
50
+ nonfiniteIndices[i] == neginfIndices[neginfIndex])
51
+ ++neginfIndex;
52
+ else
53
+ result[outputIndex++] = nonfiniteIndices[i];
54
+ }
55
+
56
+ return result;
57
+ }
58
+
59
+ #endif
60
+
61
+ } // namespace mlpack
62
+
63
+ #endif
@@ -0,0 +1,48 @@
1
+ /**
2
+ * @file core/cereal/low_precision.hpp
3
+ * @author Ryan Curtin
4
+ *
5
+ * Extra shims necessary for cereal to serialize to JSON for low-precision types
6
+ * (e.g. FP16, BF16, etc.).
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_CEREAL_LOW_PRECISION_HPP
14
+ #define MLPACK_CORE_CEREAL_LOW_PRECISION_HPP
15
+
16
+ namespace cereal {
17
+
18
+ // Because our serialization is always done with name-value pairs, we can catch
19
+ // any FP16 serialization at the NVP level with a specialized implementation of
20
+ // the load and save functions for the JSON archive (the only one that does not
21
+ // serialize low-precision correctly).
22
+
23
+ #if defined(ARMA_HAVE_FP16)
24
+
25
+ inline void CEREAL_SAVE_FUNCTION_NAME(JSONOutputArchive &ar,
26
+ NameValuePair<arma::fp16&> const& t)
27
+ {
28
+ ar.setNextName(t.name);
29
+ std::ostringstream oss;
30
+ oss.precision(std::numeric_limits<arma::fp16>::max_digits10);
31
+ oss << t.value;
32
+ ar(oss.str());
33
+ }
34
+
35
+ inline void CEREAL_LOAD_FUNCTION_NAME(JSONInputArchive& ar,
36
+ NameValuePair<arma::fp16&>& t)
37
+ {
38
+ ar.setNextName(t.name);
39
+ std::string encoded;
40
+ ar.loadValue(encoded);
41
+ t.value = arma::fp16(std::stof(encoded));
42
+ }
43
+
44
+ #endif
45
+
46
+ } // namespace cereal
47
+
48
+ #endif
@@ -57,12 +57,12 @@ class CVBase
57
57
 
58
58
  /**
59
59
  * Assert that MLAlgorithm takes the numClasses parameter and a
60
- * data::DatasetInfo parameter and store them.
60
+ * DatasetInfo parameter and store them.
61
61
  *
62
62
  * @param datasetInfo Type information for each dimension of the dataset.
63
63
  * @param numClasses Number of classes in the dataset.
64
64
  */
65
- CVBase(const data::DatasetInfo& datasetInfo,
65
+ CVBase(const DatasetInfo& datasetInfo,
66
66
  const size_t numClasses);
67
67
 
68
68
  /**
@@ -101,9 +101,9 @@ class CVBase
101
101
  static_assert(MIE::IsSupported,
102
102
  "The given MLAlgorithm is not supported by MetaInfoExtractor");
103
103
 
104
- //! A variable for storing a data::DatasetInfo parameter if it is passed.
105
- const data::DatasetInfo datasetInfo;
106
- //! An indicator whether a data::DatasetInfo parameter has been passed.
104
+ //! A variable for storing a DatasetInfo parameter if it is passed.
105
+ const DatasetInfo datasetInfo;
106
+ //! An indicator whether a DatasetInfo parameter has been passed.
107
107
  const bool isDatasetInfoPassed;
108
108
  //! A variable for storing the numClasses parameter if it is passed.
109
109
  size_t numClasses;
@@ -145,7 +145,7 @@ class CVBase
145
145
 
146
146
  /**
147
147
  * Construct a trained MLAlgorithm model if MLAlgorithm takes the
148
- * numClasses parameter and a data::DatasetInfo parameter.
148
+ * numClasses parameter and a DatasetInfo parameter.
149
149
  */
150
150
  template<typename... MLAlgorithmArgs,
151
151
  bool Enabled = MIE::TakesNumClasses & MIE::TakesDatasetInfo,
@@ -183,7 +183,7 @@ class CVBase
183
183
 
184
184
  /**
185
185
  * Construct a trained MLAlgorithm model if MLAlgorithm takes the
186
- * numClasses parameter and a data::DatasetInfo parameter.
186
+ * numClasses parameter and a DatasetInfo parameter.
187
187
  */
188
188
  template<typename... MLAlgorithmArgs,
189
189
  bool Enabled = MIE::TakesNumClasses & MIE::TakesDatasetInfo,
@@ -196,13 +196,13 @@ class CVBase
196
196
  const MLAlgorithmArgs&... args);
197
197
 
198
198
  /**
199
- * When MLAlgorithm supports a data::DatasetInfo parameter, training should be
199
+ * When MLAlgorithm supports a DatasetInfo parameter, training should be
200
200
  * treated separately - there are models that can be constructed with and
201
201
  * without a data:DatasetInfo parameter and models that can be constructed
202
- * only with a data::DatasetInfo parameter.
202
+ * only with a DatasetInfo parameter.
203
203
  *
204
204
  * Construct a trained MLAlgorithm model when it can be constructed without a
205
- * data::DatasetInfo parameter.
205
+ * DatasetInfo parameter.
206
206
  */
207
207
  template<bool ConstructableWithoutDatasetInfo,
208
208
  typename... MLAlgorithmArgs,
@@ -213,7 +213,7 @@ class CVBase
213
213
 
214
214
  /**
215
215
  * Construct a trained MLAlgorithm model when it can't be constructed without
216
- * a data::DatasetInfo parameter.
216
+ * a DatasetInfo parameter.
217
217
  */
218
218
  template<bool ConstructableWithoutDatasetInfo,
219
219
  typename... MLAlgorithmArgs,
@@ -54,7 +54,7 @@ template<typename MLAlgorithm,
54
54
  CVBase<MLAlgorithm,
55
55
  MatType,
56
56
  PredictionsType,
57
- WeightsType>::CVBase(const data::DatasetInfo& datasetInfo,
57
+ WeightsType>::CVBase(const DatasetInfo& datasetInfo,
58
58
  const size_t numClasses) :
59
59
  datasetInfo(datasetInfo),
60
60
  isDatasetInfoPassed(true),
@@ -63,7 +63,7 @@ CVBase<MLAlgorithm,
63
63
  static_assert(MIE::TakesNumClasses,
64
64
  "The given MLAlgorithm does not take the numClasses parameter");
65
65
  static_assert(MIE::TakesDatasetInfo,
66
- "The given MLAlgorithm does not accept a data::DatasetInfo parameter");
66
+ "The given MLAlgorithm does not accept a DatasetInfo parameter");
67
67
  }
68
68
 
69
69
  template<typename MLAlgorithm,
@@ -184,9 +184,9 @@ MLAlgorithm CVBase<MLAlgorithm,
184
184
  {
185
185
  static_assert(
186
186
  std::is_constructible_v<MLAlgorithm, const MatType&,
187
- const data::DatasetInfo, const PredictionsType&, const size_t,
187
+ const DatasetInfo, const PredictionsType&, const size_t,
188
188
  MLAlgorithmArgs...>,
189
- "The given MLAlgorithm is not constructible with a data::DatasetInfo "
189
+ "The given MLAlgorithm is not constructible with a DatasetInfo "
190
190
  "parameter and the passed arguments");
191
191
 
192
192
  static const bool constructableWithoutDatasetInfo =
@@ -256,9 +256,9 @@ MLAlgorithm CVBase<MLAlgorithm,
256
256
  {
257
257
  static_assert(
258
258
  std::is_constructible_v<MLAlgorithm, const MatType&,
259
- const data::DatasetInfo, const PredictionsType&, const size_t,
259
+ const DatasetInfo, const PredictionsType&, const size_t,
260
260
  const WeightsType&, MLAlgorithmArgs...>,
261
- "The given MLAlgorithm is not constructible with a data::DatasetInfo "
261
+ "The given MLAlgorithm is not constructible with a DatasetInfo "
262
262
  "parameter and the passed arguments");
263
263
 
264
264
  static const bool constructableWithoutDatasetInfo =
@@ -302,7 +302,7 @@ MLAlgorithm CVBase<MLAlgorithm,
302
302
  {
303
303
  if (!isDatasetInfoPassed)
304
304
  throw std::invalid_argument(
305
- "The given MLAlgorithm requires a data::DatasetInfo parameter");
305
+ "The given MLAlgorithm requires a DatasetInfo parameter");
306
306
 
307
307
  return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
308
308
  }
@@ -14,6 +14,7 @@
14
14
 
15
15
  #include <mlpack/core/cv/meta_info_extractor.hpp>
16
16
  #include <mlpack/core/cv/cv_base.hpp>
17
+ #include <mlpack/core/util/arma_traits.hpp>
17
18
 
18
19
  namespace mlpack {
19
20
 
@@ -96,7 +97,7 @@ class KFoldCV
96
97
 
97
98
  /**
98
99
  * This constructor can be used for multiclass classification algorithms that
99
- * can take a data::DatasetInfo parameter.
100
+ * can take a DatasetInfo parameter.
100
101
  *
101
102
  * @param k Number of folds (should be at least 2).
102
103
  * @param xs Data points to cross-validate on.
@@ -107,7 +108,7 @@ class KFoldCV
107
108
  */
108
109
  KFoldCV(const size_t k,
109
110
  const MatType& xs,
110
- const data::DatasetInfo& datasetInfo,
111
+ const DatasetInfo& datasetInfo,
111
112
  const PredictionsType& ys,
112
113
  const size_t numClasses,
113
114
  const bool shuffle = true);
@@ -149,7 +150,7 @@ class KFoldCV
149
150
 
150
151
  /**
151
152
  * This constructor can be used for multiclass classification algorithms that
152
- * can take a data::DatasetInfo parameter and support weighted learning.
153
+ * can take a DatasetInfo parameter and support weighted learning.
153
154
  *
154
155
  * @param k Number of folds (should be at least 2).
155
156
  * @param xs Data points to cross-validate on.
@@ -161,7 +162,7 @@ class KFoldCV
161
162
  */
162
163
  KFoldCV(const size_t k,
163
164
  const MatType& xs,
164
- const data::DatasetInfo& datasetInfo,
165
+ const DatasetInfo& datasetInfo,
165
166
  const PredictionsType& ys,
166
167
  const size_t numClasses,
167
168
  const WeightsType& weights,
@@ -280,30 +281,38 @@ class KFoldCV
280
281
  /**
281
282
  * Get the ith training subset from a variable of a matrix type.
282
283
  */
283
- template<typename ElementType>
284
- inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
285
- const size_t i);
284
+ template<typename SubsetMatType>
285
+ inline SubsetMatType GetTrainingSubset(
286
+ SubsetMatType& m,
287
+ const size_t i,
288
+ const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>* = 0);
286
289
 
287
290
  /**
288
291
  * Get the ith training subset from a variable of a row type.
289
292
  */
290
- template<typename ElementType>
291
- inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
292
- const size_t i);
293
+ template<typename SubsetRowType>
294
+ inline SubsetRowType GetTrainingSubset(
295
+ SubsetRowType& r,
296
+ const size_t i,
297
+ const typename std::enable_if_t<IsRow<SubsetRowType>::value>* = 0);
293
298
 
294
299
  /**
295
300
  * Get the ith validation subset from a variable of a matrix type.
296
301
  */
297
- template<typename ElementType>
298
- inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
299
- const size_t i);
302
+ template<typename SubsetMatType>
303
+ inline SubsetMatType GetValidationSubset(
304
+ SubsetMatType& m,
305
+ const size_t i,
306
+ const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>* = 0);
300
307
 
301
308
  /**
302
309
  * Get the ith validation subset from a variable of a row type.
303
310
  */
304
- template<typename ElementType>
305
- inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
306
- const size_t i);
311
+ template<typename SubsetRowType>
312
+ inline SubsetRowType GetValidationSubset(
313
+ SubsetRowType& r,
314
+ const size_t i,
315
+ const typename std::enable_if_t<IsRow<SubsetRowType>::value>* = 0);
307
316
  };
308
317
 
309
318
  } // namespace mlpack
@@ -58,7 +58,7 @@ KFoldCV<MLAlgorithm,
58
58
  PredictionsType,
59
59
  WeightsType>::KFoldCV(const size_t k,
60
60
  const MatType& xs,
61
- const data::DatasetInfo& datasetInfo,
61
+ const DatasetInfo& datasetInfo,
62
62
  const PredictionsType& ys,
63
63
  const size_t numClasses,
64
64
  const bool shuffle) :
@@ -111,7 +111,7 @@ KFoldCV<MLAlgorithm,
111
111
  PredictionsType,
112
112
  WeightsType>::KFoldCV(const size_t k,
113
113
  const MatType& xs,
114
- const data::DatasetInfo& datasetInfo,
114
+ const DatasetInfo& datasetInfo,
115
115
  const PredictionsType& ys,
116
116
  const size_t numClasses,
117
117
  const WeightsType& weights,
@@ -270,7 +270,7 @@ double KFoldCV<MLAlgorithm,
270
270
  return 0.0;
271
271
  }
272
272
 
273
- return arma::mean(evaluations.elem(arma::find_finite(evaluations)));
273
+ return mean(evaluations.elem(find_finite(evaluations)));
274
274
  }
275
275
 
276
276
  template<typename MLAlgorithm,
@@ -300,7 +300,7 @@ double KFoldCV<MLAlgorithm,
300
300
  modelPtr.reset(new MLAlgorithm(std::move(model)));
301
301
  }
302
302
 
303
- return arma::mean(evaluations);
303
+ return mean(evaluations);
304
304
  }
305
305
 
306
306
  template<typename MLAlgorithm,
@@ -375,14 +375,15 @@ template<typename MLAlgorithm,
375
375
  typename MatType,
376
376
  typename PredictionsType,
377
377
  typename WeightsType>
378
- template<typename ElementType>
379
- arma::Mat<ElementType> KFoldCV<MLAlgorithm,
380
- Metric,
381
- MatType,
382
- PredictionsType,
383
- WeightsType>::GetTrainingSubset(
384
- arma::Mat<ElementType>& m,
385
- const size_t i)
378
+ template<typename SubsetMatType>
379
+ SubsetMatType KFoldCV<MLAlgorithm,
380
+ Metric,
381
+ MatType,
382
+ PredictionsType,
383
+ WeightsType>::GetTrainingSubset(
384
+ SubsetMatType& m,
385
+ const size_t i,
386
+ const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>*)
386
387
  {
387
388
  // If this is not the first fold, we have to handle it a little bit
388
389
  // differently, since the last fold may contain slightly more than 'binSize'
@@ -390,8 +391,9 @@ arma::Mat<ElementType> KFoldCV<MLAlgorithm,
390
391
  const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
391
392
  (k - 1) * binSize;
392
393
 
393
- return arma::Mat<ElementType>(m.colptr(binSize * i), m.n_rows, subsetSize,
394
- false, true);
394
+ SubsetMatType alias;
395
+ MakeAlias(alias, m, m.n_rows, subsetSize, m.n_rows * binSize * i);
396
+ return alias;
395
397
  }
396
398
 
397
399
  template<typename MLAlgorithm,
@@ -399,14 +401,15 @@ template<typename MLAlgorithm,
399
401
  typename MatType,
400
402
  typename PredictionsType,
401
403
  typename WeightsType>
402
- template<typename ElementType>
403
- arma::Row<ElementType> KFoldCV<MLAlgorithm,
404
- Metric,
405
- MatType,
406
- PredictionsType,
407
- WeightsType>::GetTrainingSubset(
408
- arma::Row<ElementType>& r,
409
- const size_t i)
404
+ template<typename SubsetRowType>
405
+ SubsetRowType KFoldCV<MLAlgorithm,
406
+ Metric,
407
+ MatType,
408
+ PredictionsType,
409
+ WeightsType>::GetTrainingSubset(
410
+ SubsetRowType& r,
411
+ const size_t i,
412
+ const typename std::enable_if_t<IsRow<SubsetRowType>::value>*)
410
413
  {
411
414
  // If this is not the first fold, we have to handle it a little bit
412
415
  // differently, since the last fold may contain slightly more than 'binSize'
@@ -414,7 +417,9 @@ arma::Row<ElementType> KFoldCV<MLAlgorithm,
414
417
  const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
415
418
  (k - 1) * binSize;
416
419
 
417
- return arma::Row<ElementType>(r.colptr(binSize * i), subsetSize, false, true);
420
+ SubsetRowType alias;
421
+ MakeAlias(alias, r, subsetSize, r.n_rows * binSize * i);
422
+ return alias;
418
423
  }
419
424
 
420
425
  template<typename MLAlgorithm,
@@ -422,18 +427,21 @@ template<typename MLAlgorithm,
422
427
  typename MatType,
423
428
  typename PredictionsType,
424
429
  typename WeightsType>
425
- template<typename ElementType>
426
- arma::Mat<ElementType> KFoldCV<MLAlgorithm,
427
- Metric,
428
- MatType,
429
- PredictionsType,
430
- WeightsType>::GetValidationSubset(
431
- arma::Mat<ElementType>& m,
432
- const size_t i)
430
+ template<typename SubsetMatType>
431
+ SubsetMatType KFoldCV<MLAlgorithm,
432
+ Metric,
433
+ MatType,
434
+ PredictionsType,
435
+ WeightsType>::GetValidationSubset(
436
+ SubsetMatType& m,
437
+ const size_t i,
438
+ const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>*)
433
439
  {
434
440
  const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
435
- return arma::Mat<ElementType>(m.colptr(ValidationSubsetFirstCol(i)), m.n_rows,
436
- subsetSize, false, true);
441
+ SubsetMatType alias;
442
+ MakeAlias(alias, m, m.n_rows, subsetSize,
443
+ m.n_rows * ValidationSubsetFirstCol(i));
444
+ return alias;
437
445
  }
438
446
 
439
447
  template<typename MLAlgorithm,
@@ -441,18 +449,20 @@ template<typename MLAlgorithm,
441
449
  typename MatType,
442
450
  typename PredictionsType,
443
451
  typename WeightsType>
444
- template<typename ElementType>
445
- arma::Row<ElementType> KFoldCV<MLAlgorithm,
446
- Metric,
447
- MatType,
448
- PredictionsType,
449
- WeightsType>::GetValidationSubset(
450
- arma::Row<ElementType>& r,
451
- const size_t i)
452
+ template<typename SubsetRowType>
453
+ SubsetRowType KFoldCV<MLAlgorithm,
454
+ Metric,
455
+ MatType,
456
+ PredictionsType,
457
+ WeightsType>::GetValidationSubset(
458
+ SubsetRowType& r,
459
+ const size_t i,
460
+ const typename std::enable_if_t<IsRow<SubsetRowType>::value>*)
452
461
  {
453
462
  const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
454
- return arma::Row<ElementType>(r.colptr(ValidationSubsetFirstCol(i)),
455
- subsetSize, false, true);
463
+ SubsetRowType alias;
464
+ MakeAlias(alias, r, subsetSize, r.n_rows * ValidationSubsetFirstCol(i));
465
+ return alias;
456
466
  }
457
467
 
458
468
  } // namespace mlpack