mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415)
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -94,12 +94,12 @@ bool RPTreeMeanSplit<BoundType, MatType>::GetDotMedian(
94
94
  for (size_t k = 0; k < samples.n_elem; ++k)
95
95
  values[k] = dot(data.col(samples[k]), direction);
96
96
 
97
- const ElemType maximum = arma::max(values);
97
+ const ElemType maximum = max(values);
98
98
  const ElemType minimum = min(values);
99
99
  if (minimum == maximum)
100
100
  return false;
101
101
 
102
- splitVal = arma::median(values);
102
+ splitVal = median(values);
103
103
 
104
104
  if (splitVal == maximum)
105
105
  splitVal = minimum;
@@ -111,29 +111,29 @@ template<typename BoundType, typename MatType>
111
111
  bool RPTreeMeanSplit<BoundType, MatType>::GetMeanMedian(
112
112
  const MatType& data,
113
113
  const arma::uvec& samples,
114
- arma::Col<ElemType>& mean,
114
+ arma::Col<ElemType>& meanCol,
115
115
  ElemType& splitVal)
116
116
  {
117
117
  arma::Col<ElemType> values(samples.n_elem);
118
118
 
119
- mean = arma::mean(data.cols(samples), 1);
119
+ meanCol = mean(data.cols(samples), 1);
120
120
 
121
121
  arma::Col<ElemType> tmp(data.n_rows);
122
122
 
123
123
  for (size_t k = 0; k < samples.n_elem; ++k)
124
124
  {
125
125
  tmp = data.col(samples[k]);
126
- tmp -= mean;
126
+ tmp -= meanCol;
127
127
 
128
128
  values[k] = dot(tmp, tmp);
129
129
  }
130
130
 
131
- const ElemType maximum = arma::max(values);
131
+ const ElemType maximum = max(values);
132
132
  const ElemType minimum = min(values);
133
133
  if (minimum == maximum)
134
134
  return false;
135
135
 
136
- splitVal = arma::median(values);
136
+ splitVal = median(values);
137
137
 
138
138
  if (splitVal == maximum)
139
139
  splitVal = minimum;
@@ -143,10 +143,10 @@ class CellBound
143
143
  ElemType& MinWidth() { return minWidth; }
144
144
 
145
145
  //! Get the distance metric associated with this bound.
146
- [[deprecated("Will be removed in 5.0.0; use Distance()")]]
146
+ [[deprecated("Will be removed in mlpack 5.0.0; use Distance()")]]
147
147
  const DistanceType& Metric() const { return distance; }
148
148
  //! Modify the distance metric associated with this bound.
149
- [[deprecated("Will be removed in 5.0.0; use Distance()")]]
149
+ [[deprecated("Will be removed in mlpack 5.0.0; use Distance()")]]
150
150
  DistanceType& Metric() { return distance; }
151
151
 
152
152
  //! Get the distance metric associated with this bound.
@@ -33,7 +33,7 @@ inline CosineTree<MatType>::CosineTree(const MatType& dataset) :
33
33
  for (size_t i = 0; i < numColumns; ++i)
34
34
  {
35
35
  indices[i] = i;
36
- double l2Norm = (double) arma::norm(dataset.col(i), 2);
36
+ double l2Norm = (double) norm(dataset.col(i), 2);
37
37
  l2NormsSquared(i) = l2Norm * l2Norm;
38
38
  }
39
39
 
@@ -92,7 +92,7 @@ inline CosineTree<MatType>::CosineTree(const MatType& dataset,
92
92
 
93
93
  // Define root node of the tree and add it to the queue.
94
94
  CosineTree root(dataset);
95
- VecType tempVector = arma::zeros<VecType>(dataset.n_rows);
95
+ VecType tempVector = VecType(dataset.n_rows, GetFillType<VecType>::zeros);
96
96
  root.L2Error(-1.0); // We don't know what the error is.
97
97
  root.BasisVector(tempVector);
98
98
  treeQueue.push_back(&root);
@@ -412,8 +412,8 @@ inline void CosineTree<MatType>::ModifiedGramSchmidt(
412
412
  }
413
413
 
414
414
  // Normalize the modified centroid vector.
415
- if (arma::norm(newBasisVector, 2))
416
- newBasisVector /= arma::norm(newBasisVector, 2);
415
+ if (norm(newBasisVector, 2))
416
+ newBasisVector /= norm(newBasisVector, 2);
417
417
  }
418
418
 
419
419
  template<typename MatType>
@@ -475,7 +475,7 @@ inline double CosineTree<MatType>::MonteCarloError(
475
475
  }
476
476
 
477
477
  // Calculate the Frobenius norm squared of the projected vector.
478
- double frobProjection = arma::norm(projection, "frob");
478
+ double frobProjection = norm(projection, "frob");
479
479
  double frobProjectionSquared = frobProjection * frobProjection;
480
480
 
481
481
  // Calculate the weighted projection magnitude.
@@ -483,8 +483,8 @@ inline double CosineTree<MatType>::MonteCarloError(
483
483
  }
484
484
 
485
485
  // Compute mean and standard deviation of the weighted samples.
486
- double mu = arma::mean(weightedMagnitudes);
487
- double sigma = arma::stddev(weightedMagnitudes);
486
+ double mu = mean(weightedMagnitudes);
487
+ double sigma = stddev(weightedMagnitudes);
488
488
 
489
489
  if (!sigma)
490
490
  {
@@ -536,7 +536,7 @@ inline void CosineTree<MatType>::CosineNodeSplit()
536
536
 
537
537
  // Compute maximum and minimum cosine values.
538
538
  double cosineMax, cosineMin;
539
- cosineMax = arma::max(cosines % (cosines < 1));
539
+ cosineMax = max(cosines % (cosines < 1));
540
540
  cosineMin = min(cosines);
541
541
 
542
542
  std::vector<size_t> leftIndices, rightIndices;
@@ -670,8 +670,8 @@ inline void CosineTree<MatType>::CalculateCosines(
670
670
  else
671
671
  {
672
672
  cosines(i) =
673
- std::abs(arma::norm_dot(dataset->col(indices[splitPointIndex]),
674
- dataset->col(indices[i])));
673
+ std::abs(norm_dot(dataset->col(indices[splitPointIndex]),
674
+ dataset->col(indices[i])));
675
675
  }
676
676
  }
677
677
  }
@@ -402,6 +402,16 @@ class Octree
402
402
  const VecType& point,
403
403
  typename std::enable_if_t<IsVector<VecType>::value>* = 0) const;
404
404
 
405
+ //! Return the index of the beginning point of this subset.
406
+ size_t Begin() const { return begin; }
407
+ //! Modify the index of the beginning point of this subset.
408
+ size_t& Begin() { return begin; }
409
+
410
+ //! Return the number of points in this subset.
411
+ size_t Count() const { return count; }
412
+ //! Modify the number of points in this subset.
413
+ size_t& Count() { return count; }
414
+
405
415
  //! Store the center of the bounding region in the given vector.
406
416
  template<typename VecType>
407
417
  void Center(VecType& center) const { bound.Center(center); }
@@ -288,8 +288,13 @@ Octree<DistanceType, StatisticType, MatType>::Octree(
288
288
  // Calculate empirical center of data.
289
289
  bound |= dataset->cols(begin, begin + count - 1);
290
290
 
291
- // Now split the node.
292
- SplitNode(center, width, maxLeafSize);
291
+ ElemType maxWidth = 0.0;
292
+ for (size_t i = 0; i < bound.Dim(); ++i)
293
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
294
+ maxWidth = bound[i].Hi() - bound[i].Lo();
295
+
296
+ if (maxWidth != 0.0)
297
+ SplitNode(center, width, maxLeafSize);
293
298
 
294
299
  // Calculate the distance from the empirical center of this node to the
295
300
  // empirical center of the parent.
@@ -323,8 +328,13 @@ Octree<DistanceType, StatisticType, MatType>::Octree(
323
328
  // Calculate empirical center of data.
324
329
  bound |= dataset->cols(begin, begin + count - 1);
325
330
 
326
- // Now split the node.
327
- SplitNode(center, width, oldFromNew, maxLeafSize);
331
+ ElemType maxWidth = 0.0;
332
+ for (size_t i = 0; i < bound.Dim(); ++i)
333
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
334
+ maxWidth = bound[i].Hi() - bound[i].Lo();
335
+
336
+ if (maxWidth != 0.0)
337
+ SplitNode(center, width, oldFromNew, maxLeafSize);
328
338
 
329
339
  // Calculate the distance from the empirical center of this node to the
330
340
  // empirical center of the parent.
@@ -12,6 +12,16 @@
12
12
  #ifndef MLPACK_CORE_UTIL_ARMA_TRAITS_HPP
13
13
  #define MLPACK_CORE_UTIL_ARMA_TRAITS_HPP
14
14
 
15
+ // Get whether or not the given type is any non-field Armadillo type
16
+ // This includes sparse, dense, and cube types
17
+ template<typename T>
18
+ struct IsArma
19
+ {
20
+ constexpr static bool value = arma::is_arma_type<T>::value ||
21
+ arma::is_arma_cube_type<T>::value ||
22
+ arma::is_arma_sparse_type<T>::value;
23
+ };
24
+
15
25
  // Structs have public members by default (that's why they are chosen over
16
26
  // classes).
17
27
 
@@ -154,6 +164,15 @@ struct GetRowType<arma::SpMat<eT>>
154
164
  using type = arma::SpRow<eT>;
155
165
  };
156
166
 
167
+ template<typename MatType, typename T = void>
168
+ struct GetURowType;
169
+
170
+ template<typename MatType>
171
+ struct GetURowType<MatType, std::enable_if_t<IsArma<MatType>::value>>
172
+ {
173
+ using type = arma::Row<arma::uword>;
174
+ };
175
+
157
176
  // Get the column vector type corresponding to a given MatType.
158
177
 
159
178
  template<typename MatType>
@@ -162,8 +181,11 @@ struct GetColType
162
181
  using type = arma::Col<typename MatType::elem_type>;
163
182
  };
164
183
 
184
+ template<typename MatType, typename T = void>
185
+ struct GetUColType;
186
+
165
187
  template<typename MatType>
166
- struct GetUColType
188
+ struct GetUColType<MatType, std::enable_if_t<IsArma<MatType>::value>>
167
189
  {
168
190
  using type = arma::Col<arma::uword>;
169
191
  };
@@ -239,16 +261,6 @@ struct GetCubeType<arma::Mat<eT>>
239
261
  using type = arma::Cube<eT>;
240
262
  };
241
263
 
242
- #if defined(MLPACK_HAS_COOT)
243
-
244
- template<typename eT>
245
- struct GetCubeType<coot::Mat<eT>>
246
- {
247
- using type = coot::Cube<eT>;
248
- };
249
-
250
- #endif
251
-
252
264
  // Get the sparse matrix type corresponding to a given MatType.
253
265
 
254
266
  template<typename MatType>
@@ -356,35 +368,10 @@ struct IsDense<arma::Mat<eT>>
356
368
  constexpr static bool value = true;
357
369
  };
358
370
 
359
- // Get whether or not the given type is any non-field Armadillo type
360
- // This includes sparse, dense, and cube types
361
371
  template<typename T>
362
- struct IsArma
372
+ struct IsSparse
363
373
  {
364
- constexpr static bool value = arma::is_arma_type<T>::value ||
365
- arma::is_arma_cube_type<T>::value ||
366
- arma::is_arma_sparse_type<T>::value;
374
+ constexpr static bool value = arma::is_arma_sparse_type<T>::value;
367
375
  };
368
376
 
369
- #if defined(MLPACK_HAS_COOT)
370
-
371
- // Get whether or not the given type is any Bandicoot type
372
- // This includes dense and cube types
373
- template<typename T>
374
- struct IsCoot
375
- {
376
- constexpr static bool value = coot::is_coot_type<T>::value ||
377
- coot::is_coot_cube_type<T>::value;
378
- };
379
-
380
- #else
381
-
382
- template<typename T>
383
- struct IsCoot
384
- {
385
- constexpr static bool value = false;
386
- };
387
-
388
- #endif
389
-
390
377
  #endif
@@ -0,0 +1,97 @@
1
+ /**
2
+ * @file core/util/coot_traits.hpp
3
+ * @author Ryan Curtin
4
+ * @author Omar Shrit
5
+ *
6
+ * Some traits used for template metaprogramming (SFINAE) with Bandicoot types.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_CORE_UTIL_COOT_TRAITS_HPP
14
+ #define MLPACK_CORE_UTIL_COOT_TRAITS_HPP
15
+
16
+ #if defined(MLPACK_HAS_COOT)
17
+
18
+ // Get whether or not the given type is any Bandicoot type
19
+ // This includes dense and cube types
20
+ template<typename T>
21
+ struct IsCoot
22
+ {
23
+ constexpr static bool value = coot::is_coot_type<T>::value ||
24
+ coot::is_coot_cube_type<T>::value;
25
+ };
26
+
27
+ template<typename eT>
28
+ struct GetCubeType<coot::Mat<eT>>
29
+ {
30
+ using type = coot::Cube<eT>;
31
+ };
32
+
33
+ template<typename eT>
34
+ struct GetDenseMatType<coot::Cube<eT>>
35
+ {
36
+ using type = coot::Mat<eT>;
37
+ };
38
+
39
+ template<typename MatType>
40
+ struct GetURowType<MatType, std::enable_if_t<IsCoot<MatType>::value>>
41
+ {
42
+ using type = coot::Row<coot::uword>;
43
+ };
44
+
45
+ template<typename MatType>
46
+ struct GetUColType<MatType, std::enable_if_t<IsCoot<MatType>::value>>
47
+ {
48
+ using type = coot::Col<coot::uword>;
49
+ };
50
+
51
+ template<typename eT>
52
+ struct IsVector<coot::Col<eT> >
53
+ {
54
+ static const bool value = true;
55
+ };
56
+
57
+ template<typename eT>
58
+ struct IsVector<coot::Row<eT> >
59
+ {
60
+ static const bool value = true;
61
+ };
62
+
63
+ template<typename eT>
64
+ struct IsVector<coot::subview_col<eT> >
65
+ {
66
+ static const bool value = true;
67
+ };
68
+
69
+ template<typename eT>
70
+ struct IsVector<coot::subview_row<eT> >
71
+ {
72
+ static const bool value = true;
73
+ };
74
+
75
+ template<typename eT>
76
+ struct IsMatrix<coot::Mat<eT> >
77
+ {
78
+ static const bool value = true;
79
+ };
80
+
81
+ template<typename eT>
82
+ struct IsCube<coot::Cube<eT> >
83
+ {
84
+ static const bool value = true;
85
+ };
86
+
87
+ #else
88
+
89
+ template<typename T>
90
+ struct IsCoot
91
+ {
92
+ constexpr static bool value = false;
93
+ };
94
+
95
+ #endif // defined(MLPACK_HAS_COOT)
96
+
97
+ #endif
@@ -30,7 +30,6 @@ class Timers;
30
30
  #include "params.hpp"
31
31
 
32
32
  namespace mlpack {
33
- namespace data {
34
33
 
35
34
  class IncrementPolicy;
36
35
 
@@ -44,7 +43,6 @@ using DatasetInfo = DatasetMapper<IncrementPolicy, std::string>;
44
43
  // DatasetInfo.
45
44
  void CheckCategoricalParam(util::Params& p, const std::string& paramName);
46
45
 
47
- } // namespace data
48
46
  } // namespace mlpack
49
47
 
50
48
  #endif
@@ -757,10 +757,10 @@
757
757
  * here---it will cause problems.
758
758
  * @param ALIAS One-character string representing the alias of the parameter.
759
759
  */
760
- #define TUPLE_TYPE std::tuple<mlpack::data::DatasetInfo, arma::mat>
760
+ #define TUPLE_TYPE std::tuple<mlpack::DatasetInfo, arma::mat>
761
761
  #define PARAM_MATRIX_AND_INFO_IN(ID, DESC, ALIAS) \
762
762
  PARAM(TUPLE_TYPE, ID, DESC, ALIAS, \
763
- "std::tuple<mlpack::data::DatasetInfo, arma::mat>", false, true, true, \
763
+ "std::tuple<mlpack::DatasetInfo, arma::mat>", false, true, true, \
764
764
  TUPLE_TYPE())
765
765
 
766
766
  /**
@@ -789,10 +789,10 @@
789
789
  * here---it will cause problems.
790
790
  * @param ALIAS One-character string representing the alias of the parameter.
791
791
  */
792
- #define TUPLE_TYPE std::tuple<mlpack::data::DatasetInfo, arma::mat>
792
+ #define TUPLE_TYPE std::tuple<mlpack::DatasetInfo, arma::mat>
793
793
  #define PARAM_MATRIX_AND_INFO_IN_REQ(ID, DESC, ALIAS) \
794
794
  PARAM(TUPLE_TYPE, ID, DESC, ALIAS, \
795
- "std::tuple<mlpack::data::DatasetInfo, arma::mat>", true, true, true, \
795
+ "std::tuple<mlpack::DatasetInfo, arma::mat>", true, true, true, \
796
796
  TUPLE_TYPE())
797
797
 
798
798
  /**
@@ -286,11 +286,11 @@ inline void Params::CheckInputMatrices()
286
286
  {
287
287
  CheckInputMatrix(Get<arma::rowvec>(paramName), paramName);
288
288
  }
289
- else if (paramType == "std::tuple<mlpack::data::DatasetInfo, arma::mat>")
289
+ else if (paramType == "std::tuple<mlpack::DatasetInfo, arma::mat>")
290
290
  {
291
291
  // Note that CheckCategoricalParam() is a utility function that must be
292
292
  // defined after DatasetInfo is fully defined.
293
- data::CheckCategoricalParam(*this, paramName);
293
+ CheckCategoricalParam(*this, paramName);
294
294
  }
295
295
  }
296
296
  }
@@ -18,16 +18,24 @@
18
18
  #define MLPACK_CORE_UTIL_USING_HPP
19
19
 
20
20
  #include "arma_traits.hpp"
21
+ #include "coot_traits.hpp"
21
22
 
22
23
  namespace mlpack {
23
24
 
24
25
  #ifdef MLPACK_HAS_COOT
25
26
 
26
27
  /* using for bandicoot namespace*/
27
- using coot::exp;
28
+ using coot::accu;
29
+ using coot::all;
30
+ using coot::conv_to;
28
31
  using coot::dot;
32
+ using coot::exp;
33
+ using coot::find;
34
+ using coot::find_nan;
35
+ using coot::find_nonfinite;
29
36
  using coot::join_cols;
30
37
  using coot::join_rows;
38
+ using coot::linspace;
31
39
  using coot::log;
32
40
  using coot::min;
33
41
  using coot::max;
@@ -41,21 +49,38 @@ using coot::randn;
41
49
  using coot::randu;
42
50
  using coot::repmat;
43
51
  using coot::sign;
52
+ using coot::size;
53
+ using coot::sort_index;
44
54
  using coot::sqrt;
45
55
  using coot::square;
46
56
  using coot::sum;
47
57
  using coot::trans;
48
58
  using coot::vectorise;
49
59
  using coot::zeros;
60
+ #else
61
+
62
+ // Only use arma::conv_to if Bandicoot is not available: Bandicoot's conv_to
63
+ // supports Armadillo types too.
64
+ using arma::conv_to;
50
65
 
51
66
  #endif
52
67
 
53
68
  /* using for armadillo namespace */
54
- using arma::exp;
69
+ using arma::accu;
70
+ using arma::all;
55
71
  using arma::dot;
72
+ using arma::exp;
73
+ using arma::find;
74
+ #if ARMA_VERSION_MAJOR > 11 || \
75
+ (ARMA_VERSION_MAJOR == 11 && ARMA_VERSION_MINOR >= 4)
76
+ using arma::find_nan;
77
+ #endif
78
+ using arma::find_nonfinite;
56
79
  using arma::join_cols;
57
80
  using arma::join_rows;
81
+ using arma::linspace;
58
82
  using arma::log;
83
+ using arma::linspace;
59
84
  using arma::min;
60
85
  using arma::max;
61
86
  using arma::mean;
@@ -68,6 +93,8 @@ using arma::randn;
68
93
  using arma::randu;
69
94
  using arma::repmat;
70
95
  using arma::sign;
96
+ using arma::size;
97
+ using arma::sort_index;
71
98
  using arma::sqrt;
72
99
  using arma::square;
73
100
  using arma::sum;
@@ -15,10 +15,12 @@
15
15
  #include <string>
16
16
 
17
17
  // The version of mlpack. If this is a git repository, this will be a version
18
- // with higher number than the most recent release.
18
+ // with higher number than the most recent release, and the MLPACK_PRERELEASE
19
+ // macro will be defined.
19
20
  #define MLPACK_VERSION_MAJOR 4
20
- #define MLPACK_VERSION_MINOR 6
21
- #define MLPACK_VERSION_PATCH 2
21
+ #define MLPACK_VERSION_MINOR 7
22
+ #define MLPACK_VERSION_PATCH 0
23
+ //#define MLPACK_PRERELEASE
22
24
 
23
25
  // The name of the version (for use by --version).
24
26
  namespace mlpack {
@@ -20,16 +20,13 @@ namespace util {
20
20
  // name.
21
21
  inline std::string GetVersion()
22
22
  {
23
- #ifndef MLPACK_GIT_VERSION
24
23
  std::stringstream o;
25
24
  o << "mlpack " << MLPACK_VERSION_MAJOR << "." << MLPACK_VERSION_MINOR
26
25
  << "." << MLPACK_VERSION_PATCH;
26
+ #if defined(MLPACK_PRERELEASE)
27
+ o << " (prerelease)";
28
+ #endif
27
29
  return o.str();
28
- #else
29
- // This file is generated by CMake as necessary and contains just a return
30
- // statement with the git revision in it.
31
- #include "gitversion.hpp"
32
- #endif
33
30
  }
34
31
 
35
32
  } // namespace util
@@ -84,7 +84,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
84
84
  timers.Stop("adaboost_classification");
85
85
 
86
86
  Row<size_t> results;
87
- data::RevertLabels(predictedLabels, m->Mappings(), results);
87
+ RevertLabels(predictedLabels, m->Mappings(), results);
88
88
 
89
89
  params.Get<arma::Row<size_t>>("predictions") = std::move(results);
90
90
  }
@@ -108,7 +108,7 @@ BINDING_EXAMPLE(
108
108
  BINDING_SEE_ALSO("AdaBoost on Wikipedia", "https://en.wikipedia.org/wiki/"
109
109
  "AdaBoost");
110
110
  BINDING_SEE_ALSO("Improved boosting algorithms using confidence-rated "
111
- "predictions (pdf)", "http://rob.schapire.net/papers/SchapireSi98.pdf");
111
+ "predictions (pdf)", "http://www.schapire.net/papers/SchapireSi98.pdf");
112
112
  BINDING_SEE_ALSO("Perceptron", "#perceptron");
113
113
  BINDING_SEE_ALSO("Decision Trees", "#decision_tree");
114
114
  BINDING_SEE_ALSO("AdaBoost C++ class documentation",
@@ -202,7 +202,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
202
202
  Row<size_t> labels;
203
203
 
204
204
  // Normalize the labels.
205
- data::NormalizeLabels(labelsIn, labels, m->Mappings());
205
+ NormalizeLabels(labelsIn, labels, m->Mappings());
206
206
 
207
207
  // Get other training parameters.
208
208
  const double tolerance = params.Get<double>("tolerance");
@@ -253,7 +253,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
253
253
  }
254
254
 
255
255
  Row<size_t> results;
256
- data::RevertLabels(predictedLabels, m->Mappings(), results);
256
+ RevertLabels(predictedLabels, m->Mappings(), results);
257
257
 
258
258
  // Save the predicted labels.
259
259
  if (params.Has("predictions"))
@@ -84,7 +84,7 @@ BINDING_EXAMPLE(
84
84
  BINDING_SEE_ALSO("AdaBoost on Wikipedia", "https://en.wikipedia.org/wiki/"
85
85
  "AdaBoost");
86
86
  BINDING_SEE_ALSO("Improved boosting algorithms using confidence-rated "
87
- "predictions (pdf)", "http://rob.schapire.net/papers/SchapireSi98.pdf");
87
+ "predictions (pdf)", "http://www.schapire.net/papers/SchapireSi98.pdf");
88
88
  BINDING_SEE_ALSO("Perceptron", "#perceptron");
89
89
  BINDING_SEE_ALSO("Decision Trees", "#decision_tree");
90
90
  BINDING_SEE_ALSO("AdaBoost C++ class documentation",
@@ -157,7 +157,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
157
157
  Row<size_t> labels;
158
158
 
159
159
  // Normalize the labels.
160
- data::NormalizeLabels(labelsIn, labels, m->Mappings());
160
+ NormalizeLabels(labelsIn, labels, m->Mappings());
161
161
 
162
162
  // Get other training parameters.
163
163
  const double tolerance = params.Get<double>("tolerance");
@@ -35,6 +35,7 @@
35
35
  #include "elliot_function.hpp"
36
36
  #include "gaussian_function.hpp"
37
37
  #include "gelu_function.hpp"
38
+ #include "gelu_exact_function.hpp"
38
39
  #include "hard_sigmoid_function.hpp"
39
40
  #include "hard_swish_function.hpp"
40
41
  #include "identity_function.hpp"
@@ -34,7 +34,8 @@ class BipolarSigmoidFunction
34
34
  * @param x Input data.
35
35
  * @return f(x).
36
36
  */
37
- static double Fn(const double x)
37
+ template<typename ElemType>
38
+ static ElemType Fn(const ElemType x)
38
39
  {
39
40
  return (1 - std::exp(-x)) / (1 + std::exp(-x));
40
41
  }
@@ -58,9 +59,10 @@ class BipolarSigmoidFunction
58
59
  * @param y Result of Fn(x).
59
60
  * @return f'(x)
60
61
  */
61
- static double Deriv(const double /* x */, const double y)
62
+ template<typename ElemType>
63
+ static ElemType Deriv(const ElemType /* x */, const ElemType y)
62
64
  {
63
- return (1.0 - std::pow(y, 2)) / 2.0;
65
+ return (1 - std::pow(y, ElemType(2))) / 2;
64
66
  }
65
67
 
66
68
  /**
@@ -75,7 +77,7 @@ class BipolarSigmoidFunction
75
77
  const OutputVecType& y,
76
78
  DerivVecType& dy)
77
79
  {
78
- dy = (1.0 - pow(y, 2)) / 2.0;
80
+ dy = (1 - square(y)) / 2;
79
81
  }
80
82
  }; // class BipolarSigmoidFunction
81
83