mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -21,7 +21,7 @@
21
21
  namespace mlpack {
22
22
 
23
23
  template<typename MatType>
24
- PReLUType<MatType>::PReLUType(const double userAlpha) :
24
+ PReLU<MatType>::PReLU(const double userAlpha) :
25
25
  Layer<MatType>(),
26
26
  userAlpha(userAlpha)
27
27
  {
@@ -29,8 +29,8 @@ PReLUType<MatType>::PReLUType(const double userAlpha) :
29
29
  }
30
30
 
31
31
  template<typename MatType>
32
- PReLUType<MatType>::PReLUType(
33
- const PReLUType& other) :
32
+ PReLU<MatType>::PReLU(
33
+ const PReLU& other) :
34
34
  Layer<MatType>(other),
35
35
  userAlpha(other.userAlpha)
36
36
  {
@@ -38,8 +38,8 @@ PReLUType<MatType>::PReLUType(
38
38
  }
39
39
 
40
40
  template<typename MatType>
41
- PReLUType<MatType>::PReLUType(
42
- PReLUType&& other) :
41
+ PReLU<MatType>::PReLU(
42
+ PReLU&& other) :
43
43
  Layer<MatType>(std::move(other)),
44
44
  userAlpha(std::move(other.userAlpha))
45
45
  {
@@ -47,8 +47,8 @@ PReLUType<MatType>::PReLUType(
47
47
  }
48
48
 
49
49
  template<typename MatType>
50
- PReLUType<MatType>&
51
- PReLUType<MatType>::operator=(const PReLUType& other)
50
+ PReLU<MatType>&
51
+ PReLU<MatType>::operator=(const PReLU& other)
52
52
  {
53
53
  if (&other != this)
54
54
  {
@@ -60,8 +60,8 @@ PReLUType<MatType>::operator=(const PReLUType& other)
60
60
  }
61
61
 
62
62
  template<typename MatType>
63
- PReLUType<MatType>&
64
- PReLUType<MatType>::operator=(PReLUType&& other)
63
+ PReLU<MatType>&
64
+ PReLU<MatType>::operator=(PReLU&& other)
65
65
  {
66
66
  if (&other != this)
67
67
  {
@@ -73,27 +73,27 @@ PReLUType<MatType>::operator=(PReLUType&& other)
73
73
  }
74
74
 
75
75
  template<typename MatType>
76
- void PReLUType<MatType>::SetWeights(const MatType& weightsIn)
76
+ void PReLU<MatType>::SetWeights(const MatType& weightsIn)
77
77
  {
78
78
  MakeAlias(alpha, weightsIn, 1, 1);
79
79
  }
80
80
 
81
81
  template<typename MatType>
82
- void PReLUType<MatType>::CustomInitialize(
82
+ void PReLU<MatType>::CustomInitialize(
83
83
  MatType& W,
84
84
  const size_t elements)
85
85
  {
86
86
  if (elements != 1)
87
87
  {
88
- throw std::invalid_argument("PReLUType::CustomInitialize(): wrong "
88
+ throw std::invalid_argument("PReLU::CustomInitialize(): wrong "
89
89
  "elements size!");
90
90
  }
91
91
 
92
- W(0) = userAlpha;
92
+ W(0) = ElemType(userAlpha);
93
93
  }
94
94
 
95
95
  template<typename MatType>
96
- void PReLUType<MatType>::Forward(
96
+ void PReLU<MatType>::Forward(
97
97
  const MatType& input, MatType& output)
98
98
  {
99
99
  output = input;
@@ -103,14 +103,14 @@ void PReLUType<MatType>::Forward(
103
103
  }
104
104
 
105
105
  template<typename MatType>
106
- void PReLUType<MatType>::Backward(
106
+ void PReLU<MatType>::Backward(
107
107
  const MatType& input,
108
108
  const MatType& /* output */,
109
109
  const MatType& gy,
110
110
  MatType& g)
111
111
  {
112
112
  MatType derivative;
113
- derivative.set_size(arma::size(input));
113
+ derivative.set_size(size(input));
114
114
  #pragma omp for
115
115
  for (size_t i = 0; i < input.n_elem; ++i)
116
116
  derivative(i) = (input(i) >= 0) ? 1 : alpha(0);
@@ -119,7 +119,7 @@ void PReLUType<MatType>::Backward(
119
119
  }
120
120
 
121
121
  template<typename MatType>
122
- void PReLUType<MatType>::Gradient(
122
+ void PReLU<MatType>::Gradient(
123
123
  const MatType& input,
124
124
  const MatType& error,
125
125
  MatType& gradient)
@@ -131,7 +131,7 @@ void PReLUType<MatType>::Gradient(
131
131
 
132
132
  template<typename MatType>
133
133
  template<typename Archive>
134
- void PReLUType<MatType>::serialize(
134
+ void PReLU<MatType>::serialize(
135
135
  Archive& ar,
136
136
  const uint32_t /* version */)
137
137
  {
@@ -21,7 +21,7 @@
21
21
  namespace mlpack {
22
22
 
23
23
  /**
24
- * Implementation of the Radial Basis Function layer. The RBFType class, when
24
+ * Implementation of the Radial Basis Function layer. The RBF class, when
25
25
  * used with a non-linear activation function, acts as a Radial Basis Function
26
26
  * which can be used with a feed-forward neural network.
27
27
  *
@@ -45,11 +45,14 @@ template <
45
45
  typename MatType = arma::mat,
46
46
  typename Activation = GaussianFunction
47
47
  >
48
- class RBFType : public Layer<MatType>
48
+ class RBF : public Layer<MatType>
49
49
  {
50
50
  public:
51
- //! Create the RBFType object.
52
- RBFType();
51
+ // Convenience typedef to access the element type of the weights and data.
52
+ using ElemType = typename MatType::elem_type;
53
+
54
+ // Create the RBF object.
55
+ RBF();
53
56
 
54
57
  /**
55
58
  * Create the Radial Basis Function layer object using the specified
@@ -59,24 +62,36 @@ class RBFType : public Layer<MatType>
59
62
  * @param centres The centres calculated using k-means of data.
60
63
  * @param betas The beta value to be used with centres.
61
64
  */
62
- RBFType(const size_t outSize,
63
- MatType& centres,
64
- double betas = 0);
65
+ RBF(const size_t outSize,
66
+ const MatType& centres,
67
+ double betas = 0);
65
68
 
66
- //! Clone the LinearType object. This handles polymorphism correctly.
67
- RBFType* Clone() const { return new RBFType(*this); }
69
+ /**
70
+ * Create the Radial Basis Function layer object using the specified
71
+ * parameters.
72
+ *
73
+ * @param outSize The number of output units.
74
+ * @param centres The centres calculated using k-means of data.
75
+ * @param betas The beta value to be used with centres.
76
+ */
77
+ RBF(const size_t outSize,
78
+ MatType&& centres,
79
+ double betas = 0);
80
+
81
+ // Clone the RBF object. This handles polymorphism correctly.
82
+ RBF* Clone() const { return new RBF(*this); }
68
83
 
69
84
  // Virtual destructor.
70
- virtual ~RBFType() { }
85
+ virtual ~RBF() { }
71
86
 
72
- //! Copy the given RBFType layer.
73
- RBFType(const RBFType& other);
74
- //! Take ownership of the given RBFType layer.
75
- RBFType(RBFType&& other);
76
- //! Copy the given RBFType layer.
77
- RBFType& operator=(const RBFType& other);
78
- //! Take ownership of the given RBFType layer.
79
- RBFType& operator=(RBFType&& other);
87
+ // Copy the given RBF layer.
88
+ RBF(const RBF& other);
89
+ // Take ownership of the given RBF layer.
90
+ RBF(RBF&& other);
91
+ // Copy the given RBF layer.
92
+ RBF& operator=(const RBF& other);
93
+ // Take ownership of the given RBF layer.
94
+ RBF& operator=(RBF&& other);
80
95
 
81
96
  /**
82
97
  * Ordinary feed forward pass of the radial basis function.
@@ -94,11 +109,11 @@ class RBFType : public Layer<MatType>
94
109
  const MatType& /* gy */,
95
110
  MatType& /* g */);
96
111
 
97
- //! Compute the output dimensions of the layer given `InputDimensions()`. The
98
- //! RBFType layer flattens the input.
112
+ // Compute the output dimensions of the layer given `InputDimensions()`. The
113
+ // RBF layer flattens the input.
99
114
  void ComputeOutputDimensions();
100
115
 
101
- //! Get the size of the weights.
116
+ // Get the size of the weights.
102
117
  size_t WeightSize() const { return 0; }
103
118
 
104
119
  /**
@@ -108,20 +123,18 @@ class RBFType : public Layer<MatType>
108
123
  void serialize(Archive& ar, const uint32_t /* version */);
109
124
 
110
125
  private:
111
- //! Locally-stored number of output units.
126
+ // Locally-stored number of output units.
112
127
  size_t outSize;
113
128
 
114
- //! Locally-stored the betas values.
129
+ // Locally-stored the betas values.
115
130
  double betas;
116
131
 
117
- //! Locally-stored the learnable centre of the shape.
132
+ // Locally-stored the learnable centre of the shape.
118
133
  MatType centres;
119
134
 
120
- //! Locally-stored the output distances of the shape.
135
+ // Locally-stored the output distances of the shape.
121
136
  MatType distances;
122
- }; // class RBFType
123
-
124
- using RBF = RBFType<arma::mat>;
137
+ }; // class RBF
125
138
 
126
139
  } // namespace mlpack
127
140
 
@@ -16,7 +16,7 @@
16
16
  namespace mlpack {
17
17
 
18
18
  template<typename MatType, typename Activation>
19
- RBFType<MatType, Activation>::RBFType() :
19
+ RBF<MatType, Activation>::RBF() :
20
20
  Layer<MatType>(),
21
21
  outSize(0),
22
22
  betas(0)
@@ -25,9 +25,9 @@ RBFType<MatType, Activation>::RBFType() :
25
25
  }
26
26
 
27
27
  template<typename MatType, typename Activation>
28
- RBFType<MatType, Activation>::RBFType(
28
+ RBF<MatType, Activation>::RBF(
29
29
  const size_t outSize,
30
- MatType& centres,
30
+ const MatType& centres,
31
31
  double betas) :
32
32
  Layer<MatType>(),
33
33
  outSize(outSize),
@@ -41,7 +41,32 @@ RBFType<MatType, Activation>::RBFType(
41
41
  {
42
42
  double maxDis = 0;
43
43
  MatType temp = centres.each_col() - centres.col(i);
44
- maxDis = max(pow(sum(pow((temp), 2), 0), 0.5).t());
44
+ maxDis = max(sqrt(sum(square(temp), 0)).t());
45
+ if (maxDis > sigmas)
46
+ sigmas = maxDis;
47
+ }
48
+ this->betas = std::pow(2 * outSize, 0.5) / sigmas;
49
+ }
50
+ }
51
+
52
+ template<typename MatType, typename Activation>
53
+ RBF<MatType, Activation>::RBF(
54
+ const size_t outSize,
55
+ MatType&& centres,
56
+ double betas) :
57
+ Layer<MatType>(),
58
+ outSize(outSize),
59
+ betas(betas),
60
+ centres(std::move(centres))
61
+ {
62
+ double sigmas = 0;
63
+ if (betas == 0)
64
+ {
65
+ for (size_t i = 0; i < centres.n_cols; i++)
66
+ {
67
+ double maxDis = 0;
68
+ MatType temp = centres.each_col() - centres.col(i);
69
+ maxDis = max(sqrt(sum(square(temp), 0)).t());
45
70
  if (maxDis > sigmas)
46
71
  sigmas = maxDis;
47
72
  }
@@ -51,7 +76,7 @@ RBFType<MatType, Activation>::RBFType(
51
76
 
52
77
  template<typename MatType,
53
78
  typename Activation>
54
- RBFType<MatType, Activation>::RBFType(const RBFType& other) :
79
+ RBF<MatType, Activation>::RBF(const RBF& other) :
55
80
  Layer<MatType>(other),
56
81
  outSize(other.outSize),
57
82
  betas(other.betas),
@@ -62,7 +87,7 @@ RBFType<MatType, Activation>::RBFType(const RBFType& other) :
62
87
 
63
88
  template<typename MatType,
64
89
  typename Activation>
65
- RBFType<MatType, Activation>::RBFType(RBFType&& other) :
90
+ RBF<MatType, Activation>::RBF(RBF&& other) :
66
91
  Layer<MatType>(other),
67
92
  outSize(other.outSize),
68
93
  betas(other.betas),
@@ -72,8 +97,8 @@ RBFType<MatType, Activation>::RBFType(RBFType&& other) :
72
97
  }
73
98
 
74
99
  template<typename MatType, typename Activation>
75
- RBFType<MatType, Activation>&
76
- RBFType<MatType, Activation>::operator=(const RBFType& other)
100
+ RBF<MatType, Activation>&
101
+ RBF<MatType, Activation>::operator=(const RBF& other)
77
102
  {
78
103
  if (&other != this)
79
104
  {
@@ -87,8 +112,8 @@ RBFType<MatType, Activation>::operator=(const RBFType& other)
87
112
  }
88
113
 
89
114
  template<typename MatType, typename Activation>
90
- RBFType<MatType, Activation>&
91
- RBFType<MatType, Activation>::operator=(RBFType&& other)
115
+ RBF<MatType, Activation>&
116
+ RBF<MatType, Activation>::operator=(RBF&& other)
92
117
  {
93
118
  if (&other != this)
94
119
  {
@@ -102,14 +127,14 @@ RBFType<MatType, Activation>::operator=(RBFType&& other)
102
127
  }
103
128
 
104
129
  template<typename MatType, typename Activation>
105
- void RBFType<MatType, Activation>::Forward(
130
+ void RBF<MatType, Activation>::Forward(
106
131
  const MatType& input,
107
132
  MatType& output)
108
133
  {
109
134
  // Sanity check: make sure the dimensions are right.
110
135
  if (input.n_rows != centres.n_rows)
111
136
  {
112
- Log::Fatal << "RBFType::Forward(): input size (" << input.n_rows << ") does"
137
+ Log::Fatal << "RBF::Forward(): input size (" << input.n_rows << ") does"
113
138
  << " not match given center size (" << centres.n_rows << ")!"
114
139
  << std::endl;
115
140
  }
@@ -119,14 +144,14 @@ void RBFType<MatType, Activation>::Forward(
119
144
  for (size_t i = 0; i < input.n_cols; i++)
120
145
  {
121
146
  MatType temp = centres.each_col() - input.col(i);
122
- distances.col(i) = pow(sum(pow((temp), 2), 0), 0.5).t();
147
+ distances.col(i) = sqrt(sum(square(temp), 0)).t();
123
148
  }
124
- Activation::Fn(distances * std::pow(betas, 0.5), output);
149
+ Activation::Fn(distances * ElemType(std::pow(betas, 0.5)), output);
125
150
  }
126
151
 
127
152
 
128
153
  template<typename MatType, typename Activation>
129
- void RBFType<MatType, Activation>::Backward(
154
+ void RBF<MatType, Activation>::Backward(
130
155
  const MatType& /* input */,
131
156
  const MatType& /* output */,
132
157
  const MatType& /* gy */,
@@ -136,7 +161,7 @@ void RBFType<MatType, Activation>::Backward(
136
161
  }
137
162
 
138
163
  template<typename MatType, typename Activation>
139
- void RBFType<MatType, Activation>::ComputeOutputDimensions()
164
+ void RBF<MatType, Activation>::ComputeOutputDimensions()
140
165
  {
141
166
  this->outputDimensions = std::vector<size_t>(this->inputDimensions.size(), 1);
142
167
 
@@ -146,7 +171,7 @@ void RBFType<MatType, Activation>::ComputeOutputDimensions()
146
171
 
147
172
  template<typename MatType, typename Activation>
148
173
  template<typename Archive>
149
- void RBFType<MatType, Activation>::serialize(
174
+ void RBF<MatType, Activation>::serialize(
150
175
  Archive& ar,
151
176
  const uint32_t /* version */)
152
177
  {
@@ -58,7 +58,10 @@ template<typename MatType = arma::mat>
58
58
  class RecurrentLayer : public Layer<MatType>
59
59
  {
60
60
  public:
61
+ // Convenience typedefs.
62
+ using ElemType = typename MatType::elem_type;
61
63
  using CubeType = typename GetCubeType<MatType>::type;
64
+
62
65
  /**
63
66
  * Create the RecurrentLayer.
64
67
  */
@@ -124,6 +127,16 @@ class RecurrentLayer : public Layer<MatType>
124
127
  // meant to be done by the enclosing network.)
125
128
  void CurrentStep(const size_t& step, const bool end = false);
126
129
 
130
+ /**
131
+ * Update the internal state of the layer when the step changes. This is
132
+ * meant to be called by the enclosing network. A child recurrent class
133
+ * should override this.
134
+ */
135
+ virtual void OnStepChanged(const size_t /* step */,
136
+ const size_t /* batchSize */,
137
+ const size_t /* activeBatchSize */,
138
+ const bool /* backwards */) { }
139
+
127
140
  // Get the previous step. This is a very simple function but can lead to
128
141
  // slightly more readable code in Forward(), Backward(), and Gradient()
129
142
  // implementations.
@@ -14,7 +14,7 @@
14
14
  * url = {https://arxiv.org/pdf/1704.04861}
15
15
  * }
16
16
  * @endcode
17
- *
17
+ *
18
18
  * mlpack is free software; you may redistribute it and/or modify it under the
19
19
  * terms of the 3-clause BSD license. You should have received a copy of the
20
20
  * 3-clause BSD license along with mlpack. If not, see
@@ -33,28 +33,31 @@ namespace mlpack {
33
33
  * (Default: arma::mat).
34
34
  */
35
35
  template<typename MatType = arma::mat>
36
- class ReLU6Type : public Layer<MatType>
36
+ class ReLU6 : public Layer<MatType>
37
37
  {
38
38
  public:
39
+ // Convenience typedef to access the element type of the weights and data.
40
+ using ElemType = typename MatType::elem_type;
41
+
39
42
  /**
40
- * Create the ReLU6Type object.
43
+ * Create the ReLU6 object.
41
44
  */
42
- ReLU6Type();
45
+ ReLU6();
43
46
 
44
- //! Clone the ReLU6Type object. This handles polymorphism correctly.
45
- ReLU6Type* Clone() const { return new ReLU6Type(*this); }
47
+ // Clone the ReLU6 object. This handles polymorphism correctly.
48
+ ReLU6* Clone() const { return new ReLU6(*this); }
46
49
 
47
50
  // Virtual destructor.
48
- virtual ~ReLU6Type() { }
51
+ virtual ~ReLU6() { }
49
52
 
50
- //! Copy the given ReLU6Type.
51
- ReLU6Type(const ReLU6Type& other);
52
- //! Take ownership of the given ReLU6Type.
53
- ReLU6Type(ReLU6Type&& other);
54
- //! Copy the given ReLU6Type.
55
- ReLU6Type& operator=(const ReLU6Type& other);
56
- //! Take ownership of the given ReLU6Type.
57
- ReLU6Type& operator=(ReLU6Type&& other);
53
+ // Copy the given ReLU6.
54
+ ReLU6(const ReLU6& other);
55
+ // Take ownership of the given ReLU6.
56
+ ReLU6(ReLU6&& other);
57
+ // Copy the given ReLU6.
58
+ ReLU6& operator=(const ReLU6& other);
59
+ // Take ownership of the given ReLU6.
60
+ ReLU6& operator=(ReLU6&& other);
58
61
 
59
62
  /**
60
63
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -80,7 +83,7 @@ class ReLU6Type : public Layer<MatType>
80
83
  const MatType& gy,
81
84
  MatType& g);
82
85
 
83
- //! Get size of weights.
86
+ // Get size of weights.
84
87
  size_t WeightSize() const { return 0; }
85
88
 
86
89
  /**
@@ -90,11 +93,6 @@ class ReLU6Type : public Layer<MatType>
90
93
  void serialize(Archive& /* ar */, const uint32_t /* version */);
91
94
  }; // class ReLU6
92
95
 
93
- // Convenience typedefs.
94
-
95
- // Standard ReLU6 layer.
96
- using ReLU6 = ReLU6Type<arma::mat>;
97
-
98
96
  } // namespace mlpack
99
97
 
100
98
  // Include implementation.
@@ -29,31 +29,31 @@
29
29
  namespace mlpack {
30
30
 
31
31
  template<typename MatType>
32
- ReLU6Type<MatType>::ReLU6Type() :
32
+ ReLU6<MatType>::ReLU6() :
33
33
  Layer<MatType>()
34
34
  {
35
35
  // Nothing to do here.
36
36
  }
37
37
 
38
38
  template<typename MatType>
39
- ReLU6Type<MatType>::ReLU6Type(
40
- const ReLU6Type& other) :
39
+ ReLU6<MatType>::ReLU6(
40
+ const ReLU6& other) :
41
41
  Layer<MatType>(other)
42
42
  {
43
43
  // Nothing to do here.
44
44
  }
45
45
 
46
46
  template<typename MatType>
47
- ReLU6Type<MatType>::ReLU6Type(
48
- ReLU6Type&& other) :
47
+ ReLU6<MatType>::ReLU6(
48
+ ReLU6&& other) :
49
49
  Layer<MatType>(std::move(other))
50
50
  {
51
51
  // Nothing to do here.
52
52
  }
53
53
 
54
54
  template<typename MatType>
55
- ReLU6Type<MatType>&
56
- ReLU6Type<MatType>::operator=(const ReLU6Type& other)
55
+ ReLU6<MatType>&
56
+ ReLU6<MatType>::operator=(const ReLU6& other)
57
57
  {
58
58
  if (&other != this)
59
59
  {
@@ -64,8 +64,8 @@ ReLU6Type<MatType>::operator=(const ReLU6Type& other)
64
64
  }
65
65
 
66
66
  template<typename MatType>
67
- ReLU6Type<MatType>&
68
- ReLU6Type<MatType>::operator=(ReLU6Type&& other)
67
+ ReLU6<MatType>&
68
+ ReLU6<MatType>::operator=(ReLU6&& other)
69
69
  {
70
70
  if (&other != this)
71
71
  {
@@ -76,14 +76,14 @@ ReLU6Type<MatType>::operator=(ReLU6Type&& other)
76
76
  }
77
77
 
78
78
  template<typename MatType>
79
- void ReLU6Type<MatType>::Forward(
79
+ void ReLU6<MatType>::Forward(
80
80
  const MatType& input, MatType& output)
81
81
  {
82
- output = arma::clamp(input, 0.0, 6.0);
82
+ output = arma::clamp(input, 0, 6);
83
83
  }
84
84
 
85
85
  template<typename MatType>
86
- void ReLU6Type<MatType>::Backward(
86
+ void ReLU6<MatType>::Backward(
87
87
  const MatType& input,
88
88
  const MatType& /* output */,
89
89
  const MatType& gy,
@@ -95,13 +95,13 @@ void ReLU6Type<MatType>::Backward(
95
95
  if (input(i) < 6 && input(i) > 0)
96
96
  g(i) = gy(i);
97
97
  else
98
- g(i) = 0.0;
98
+ g(i) = 0;
99
99
  }
100
100
  }
101
101
 
102
102
  template<typename MatType>
103
103
  template<typename Archive>
104
- void ReLU6Type<MatType>::serialize(
104
+ void ReLU6<MatType>::serialize(
105
105
  Archive& /* ar */,
106
106
  const uint32_t /* version */)
107
107
  {
@@ -30,18 +30,20 @@ namespace mlpack {
30
30
  * computation.
31
31
  */
32
32
  template <typename MatType = arma::mat>
33
- class RepeatType : public Layer<MatType>
33
+ class Repeat : public Layer<MatType>
34
34
  {
35
35
  public:
36
- //! Get Specific Col type, not only arma
36
+ // Convenience typedefs.
37
+ using ElemType = typename MatType::elem_type;
37
38
  using UintCol = typename GetUColType<MatType>::type;
38
39
  using UintMat = typename GetUDenseMatType<MatType>::type;
40
+
39
41
  /**
40
42
  * Create the Repeat object. Multiples will be empty (e.g. 1s for all
41
43
  * dimensions), so this is the equivalent of an Identity Layer.
42
44
  * Interleave will be false (e.g. repeat in blocks).
43
45
  */
44
- RepeatType();
46
+ Repeat();
45
47
 
46
48
  /**
47
49
  * Create the Repeat object, specifying the number of times to repeat
@@ -53,24 +55,24 @@ class RepeatType : public Layer<MatType>
53
55
  * @apram interleave If true, the output will be interleaved (similar to
54
56
  * arma::repelem). If false, the output will be repeated in blocks.
55
57
  */
56
- RepeatType(std::vector<size_t> multiples, bool interleave = false);
58
+ Repeat(std::vector<size_t> multiples, bool interleave = false);
57
59
 
58
60
  /**
59
61
  * Destroy the layers held by the model.
60
62
  */
61
- virtual ~RepeatType() { }
63
+ virtual ~Repeat() { }
62
64
 
63
- //! Clone the RepeatType object. This handles polymorphism correctly.
64
- RepeatType* Clone() const override { return new RepeatType(*this); }
65
+ // Clone the Repeat object. This handles polymorphism correctly.
66
+ Repeat* Clone() const override { return new Repeat(*this); }
65
67
 
66
- //! Copy the given RepeatType layer.
67
- RepeatType(const RepeatType& other);
68
- //! Take ownership of the given RepeatType layer.
69
- RepeatType(RepeatType&& other) noexcept;
70
- //! Copy the given RepeatType layer.
71
- RepeatType& operator=(const RepeatType& other);
72
- //! Take ownership of the given RepeatType layer.
73
- RepeatType& operator=(RepeatType&& other) noexcept;
68
+ // Copy the given Repeat layer.
69
+ Repeat(const Repeat& other);
70
+ // Take ownership of the given Repeat layer.
71
+ Repeat(Repeat&& other) noexcept;
72
+ // Copy the given Repeat layer.
73
+ Repeat& operator=(const Repeat& other);
74
+ // Take ownership of the given Repeat layer.
75
+ Repeat& operator=(Repeat&& other) noexcept;
74
76
 
75
77
  /**
76
78
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -96,20 +98,20 @@ class RepeatType : public Layer<MatType>
96
98
  const MatType& gy,
97
99
  MatType& g) override;
98
100
 
99
- //! Get the repeat multiples
101
+ // Get the repeat multiples
100
102
  const std::vector<size_t>& Multiples() const { return multiples; }
101
103
 
102
- //! Get the repeat multiples for modification
104
+ // Get the repeat multiples for modification
103
105
  std::vector<size_t>& Multiples()
104
106
  {
105
107
  this->validOutputDimensions = false;
106
108
  return multiples;
107
109
  }
108
110
 
109
- //! Get the interleave parameter
111
+ // Get the interleave parameter
110
112
  bool Interleave() const { return interleave; }
111
113
 
112
- //! Get the interleave parameter for modification
114
+ // Get the interleave parameter for modification
113
115
  bool& Interleave() { return interleave; }
114
116
 
115
117
  /**
@@ -130,10 +132,10 @@ class RepeatType : public Layer<MatType>
130
132
  void serialize(Archive& ar, const uint32_t /* version */);
131
133
 
132
134
  private:
133
- //! Parameter to indicate number of times to repeat along each dimension
135
+ // Parameter to indicate number of times to repeat along each dimension
134
136
  std::vector<size_t> multiples;
135
137
 
136
- //! Parameter to indicate whether to interleave the output
138
+ // Parameter to indicate whether to interleave the output
137
139
  bool interleave;
138
140
 
139
141
  // Cache the target indices for a single tensor for use
@@ -144,10 +146,7 @@ class RepeatType : public Layer<MatType>
144
146
  // input elements for use in the backward pass.
145
147
  size_t sizeMult;
146
148
  UintMat backIdxs;
147
- }; // class RepeatType.
148
-
149
- // Standard Repeat layer.
150
- using Repeat = RepeatType<arma::mat>;
149
+ }; // class Repeat.
151
150
 
152
151
  } // namespace mlpack
153
152