mlpack-4.6.2-cp313-cp313-win_amd64.whl → mlpack-4.7.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415)
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -24,6 +24,7 @@
 #include <mlpack/methods/ann/activation_functions/mish_function.hpp>
 #include <mlpack/methods/ann/activation_functions/lisht_function.hpp>
 #include <mlpack/methods/ann/activation_functions/gelu_function.hpp>
+#include <mlpack/methods/ann/activation_functions/gelu_exact_function.hpp>
 #include <mlpack/methods/ann/activation_functions/elliot_function.hpp>
 #include <mlpack/methods/ann/activation_functions/elish_function.hpp>
 #include <mlpack/methods/ann/activation_functions/gaussian_function.hpp>
@@ -51,6 +52,7 @@ namespace mlpack {
  * - Mish
  * - LiSHT
  * - GELU
+ * - GELUExact
  * - ELiSH
  * - Elliot
  * - Gaussian
@@ -68,6 +70,9 @@ template <
 class BaseLayer : public Layer<MatType>
 {
  public:
+  // Convenience typedef to access the element type of the weights and data.
+  using ElemType = typename MatType::elem_type;
+
  /**
   * Create the BaseLayer object.
   */
@@ -83,7 +88,7 @@ class BaseLayer : public Layer<MatType>
   // members.

   //! Clone the BaseLayer object. This handles polymorphism correctly.
-  BaseLayer* Clone() const { return new BaseLayer(*this); }
+  virtual BaseLayer* Clone() const { return new BaseLayer(*this); }

   /**
    * Forward pass: apply the activation to the inputs.
@@ -131,138 +136,110 @@ class BaseLayer : public Layer<MatType>
 /**
  * Standard Sigmoid-Layer using the logistic activation function.
  */
-using Sigmoid = BaseLayer<LogisticFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using SigmoidType = BaseLayer<LogisticFunction, MatType>;
+using Sigmoid = BaseLayer<LogisticFunction, MatType>;

 /**
  * Standard rectified linear unit non-linearity layer.
  */
-using ReLU = BaseLayer<RectifierFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using ReLUType = BaseLayer<RectifierFunction, MatType>;
+using ReLU = BaseLayer<RectifierFunction, MatType>;

 /**
  * Standard hyperbolic tangent layer.
  */
-using TanH = BaseLayer<TanhFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using TanHType = BaseLayer<TanhFunction, MatType>;
+using TanH = BaseLayer<TanhFunction, MatType>;

 /**
  * Standard Softplus-Layer using the Softplus activation function.
  */
-using SoftPlus = BaseLayer<SoftplusFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using SoftPlusType = BaseLayer<SoftplusFunction, MatType>;
+using SoftPlus = BaseLayer<SoftplusFunction, MatType>;

 /**
  * Standard HardSigmoid-Layer using the HardSigmoid activation function.
  */
-using HardSigmoid = BaseLayer<HardSigmoidFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using HardSigmoidType = BaseLayer<HardSigmoidFunction, MatType>;
+using HardSigmoid = BaseLayer<HardSigmoidFunction, MatType>;

 /**
  * Standard Swish-Layer using the Swish activation function.
  */
-using Swish = BaseLayer<SwishFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using SwishType = BaseLayer<SwishFunction, MatType>;
+using Swish = BaseLayer<SwishFunction, MatType>;

 /**
  * Standard Mish-Layer using the Mish activation function.
  */
-using Mish = BaseLayer<MishFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using MishType = BaseLayer<MishFunction, MatType>;
+using Mish = BaseLayer<MishFunction, MatType>;

 /**
  * Standard LiSHT-Layer using the LiSHT activation function.
  */
-using LiSHT = BaseLayer<LiSHTFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using LiSHTType = BaseLayer<LiSHTFunction, MatType>;
+using LiSHT = BaseLayer<LiSHTFunction, MatType>;

 /**
  * Standard GELU-Layer using the GELU activation function.
  */
-using GELU = BaseLayer<GELUFunction, arma::mat>;
+template<typename MatType = arma::mat>
+using GELU = BaseLayer<GELUFunction, MatType>;

+/**
+ * Standard GELUExact-Layer using the GELUExact activation function.
+ */
 template<typename MatType = arma::mat>
-using GELUType = BaseLayer<GELUFunction, MatType>;
+using GELUExact = BaseLayer<GELUExactFunction, MatType>;

 /**
  * Standard Elliot-Layer using the Elliot activation function.
  */
-using Elliot = BaseLayer<ElliotFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using ElliotType = BaseLayer<ElliotFunction, MatType>;
+using Elliot = BaseLayer<ElliotFunction, MatType>;

 /**
  * Standard ELiSH-Layer using the ELiSH activation function.
  */
-using Elish = BaseLayer<ElishFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using ElishType = BaseLayer<ElishFunction, MatType>;
+using Elish = BaseLayer<ElishFunction, MatType>;

 /**
  * Standard Gaussian-Layer using the Gaussian activation function.
  */
-using Gaussian = BaseLayer<GaussianFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using GaussianType = BaseLayer<GaussianFunction, MatType>;
+using Gaussian = BaseLayer<GaussianFunction, MatType>;

 /**
  * Standard HardSwish-Layer using the HardSwish activation function.
  */
-using HardSwish = BaseLayer<HardSwishFunction, arma::mat>;
-
 template <typename MatType = arma::mat>
-using HardSwishType = BaseLayer<HardSwishFunction, MatType>;
+using HardSwish = BaseLayer<HardSwishFunction, MatType>;

 /**
  * Standard TanhExp-Layer using the TanhExp activation function.
  */
-using TanhExp = BaseLayer<TanhExpFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using TanhExpType = BaseLayer<TanhExpFunction, MatType>;
+using TanhExp = BaseLayer<TanhExpFunction, MatType>;

 /**
  * Standard SILU-Layer using the SILU activation function.
  */
-using SILU = BaseLayer<SILUFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using SILUType = BaseLayer<SILUFunction, MatType>;
+using SILU = BaseLayer<SILUFunction, MatType>;

 /**
  * Standard Hyper Sinh layer.
  */
-using HyperSinh = BaseLayer<HyperSinhFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using HyperSinhType = BaseLayer<HyperSinhFunction, MatType>;
+using HyperSinh = BaseLayer<HyperSinhFunction, MatType>;

 /**
  * Standard Bipolar Sigmoid layer.
  */
-using BipolarSigmoid = BaseLayer<BipolarSigmoidFunction, arma::mat>;
-
 template<typename MatType = arma::mat>
-using BipolarSigmoidType = BaseLayer<BipolarSigmoidFunction, MatType>;
+using BipolarSigmoid = BaseLayer<BipolarSigmoidFunction, MatType>;

 } // namespace mlpack

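The base_layer.hpp hunks above collapse each pair of aliases (the fixed arma::mat alias plus its *Type template, e.g. ReLU and ReLUType) into a single template alias per activation, and add the new GELUExact alias. As a rough illustration of what this means for downstream code, here is a minimal sketch that is not part of the package; it assumes an mlpack 4.7.0 include path and an Armadillo/C++17 toolchain.

    // Hypothetical usage sketch (not from the diff): the activation aliases are
    // now templates over the matrix type, so one name serves arma::mat,
    // arma::fmat, and other matrix types.
    #include <mlpack.hpp>

    int main()
    {
      arma::fmat input(8, 4, arma::fill::randn);
      arma::fmat output(8, 4);

      // Previously ReLU (arma::mat only) or ReLUType<arma::fmat>.
      mlpack::ReLU<arma::fmat> relu;
      relu.Forward(input, output);          // elementwise max(0, x)

      // New in 4.7.0 via gelu_exact_function.hpp.
      mlpack::GELUExact<arma::fmat> gelu;
      gelu.Forward(input, output);

      return 0;
    }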
@@ -50,10 +50,13 @@ namespace mlpack {
  * computation.
  */
 template <typename MatType = arma::mat>
-class BatchNormType : public Layer<MatType>
+class BatchNorm : public Layer<MatType>
 {
  public:
+  // Convenience typedefs to access the element type of the weights and data.
+  using ElemType = typename MatType::elem_type;
   using CubeType = typename GetCubeType<MatType>::type;
+
   /**
    * Create the BatchNorm object.
    *
@@ -72,7 +75,7 @@ class BatchNormType : public Layer<MatType>
   * three dimensions rows, columns and slices), and `minAxis` & `maxAxis` is
   * 2, then we apply the same normalization across different slices.
   */
-  BatchNormType();
+  BatchNorm();

  /**
   * Create the BatchNorm layer object for a specified axis of input units as
@@ -93,30 +96,30 @@ class BatchNormType : public Layer<MatType>
   * updating the parameters or momentum is used.
   * @param momentum Parameter used to to update the running mean and variance.
   */
-  BatchNormType(const size_t minAxis,
+  BatchNorm(const size_t minAxis,
             const size_t maxAxis,
             const double eps = 1e-8,
             const bool average = true,
             const double momentum = 0.1);

-  virtual ~BatchNormType() { }
+  virtual ~BatchNorm() { }

-  //! Clone the BatchNormType object. This handles polymorphism correctly.
-  BatchNormType* Clone() const { return new BatchNormType(*this); }
+  //! Clone the BatchNorm object. This handles polymorphism correctly.
+  BatchNorm* Clone() const { return new BatchNorm(*this); }

   //! Copy the other BatchNorm layer (but not weights).
-  BatchNormType(const BatchNormType& layer);
+  BatchNorm(const BatchNorm& layer);

   //! Take ownership of the members of the other BatchNorm layer (but not
   //! weights).
-  BatchNormType(BatchNormType&& layer);
+  BatchNorm(BatchNorm&& layer);

   //! Copy the other BatchNorm layer (but not weights).
-  BatchNormType& operator=(const BatchNormType& layer);
+  BatchNorm& operator=(const BatchNorm& layer);

   //! Take ownership of the members of the other BatchNorm layer (but not
   //! weights).
-  BatchNormType& operator=(BatchNormType&& layer);
+  BatchNorm& operator=(BatchNorm&& layer);

  /**
   * Reset the layer parameters.
@@ -189,7 +192,7 @@ class BatchNormType : public Layer<MatType>
   MatType& TrainingVariance() { return runningVariance; }

   //! Get the number of input units / channels.
-  size_t InputSize() const { return size; }
+  size_t InputSize() const { return inputUnits; }

   //! Get the epsilon value.
   const double &Epsilon() const { return eps; }
@@ -203,7 +206,7 @@ class BatchNormType : public Layer<MatType>
   bool Average() const { return average; }

   //! Get size of weights.
-  size_t WeightSize() const { return 2 * size; }
+  size_t WeightSize() const { return 2 * inputUnits; }

   //! Compute the output dimensions of the layer given `InputDimensions()`.
   void ComputeOutputDimensions();
@@ -253,7 +256,7 @@ class BatchNormType : public Layer<MatType>

   //! Locally-stored number of input units. (This is the product of all
   //! dimensions between minAxis and maxAxis, inclusive.)
-  size_t size;
+  size_t inputUnits;

   //! Locally-stored number of higher dimension we are not applying
   //! batch normalization to. This is the product of this->inputDimensions
@@ -273,11 +276,6 @@ class BatchNormType : public Layer<MatType>
   CubeType inputMean;
 }; // class BatchNorm

-// Convenience typedefs.
-
-// Standard Adaptive max pooling layer.
-using BatchNorm = BatchNormType<arma::mat>;
-
 } // namespace mlpack

 // Include the implementation.
@@ -22,7 +22,7 @@
 namespace mlpack {

 template<typename MatType>
-BatchNormType<MatType>::BatchNormType() :
+BatchNorm<MatType>::BatchNorm() :
     Layer<MatType>(),
     minAxis(2),
     maxAxis(2),
@@ -31,14 +31,14 @@ BatchNormType<MatType>::BatchNormType() :
     momentum(0.0),
     count(0),
     inputDimension(1),
-    size(0),
+    inputUnits(0),
     higherDimension(1)
 {
   // Nothing to do here.
 }

 template <typename MatType>
-BatchNormType<MatType>::BatchNormType(
+BatchNorm<MatType>::BatchNorm(
     const size_t minAxis,
     const size_t maxAxis,
     const double eps,
@@ -52,7 +52,7 @@ BatchNormType<MatType>::BatchNormType(
     momentum(momentum),
     count(0),
     inputDimension(1),
-    size(0),
+    inputUnits(0),
     higherDimension(1)
 {
   // Nothing to do here.
@@ -60,7 +60,7 @@ BatchNormType<MatType>::BatchNormType(

 // Copy constructor.
 template<typename MatType>
-BatchNormType<MatType>::BatchNormType(const BatchNormType& layer) :
+BatchNorm<MatType>::BatchNorm(const BatchNorm& layer) :
     Layer<MatType>(layer),
     minAxis(layer.minAxis),
     maxAxis(layer.maxAxis),
@@ -70,7 +70,7 @@ BatchNormType<MatType>::BatchNormType(const BatchNormType& layer) :
     variance(layer.variance),
     count(layer.count),
     inputDimension(layer.inputDimension),
-    size(layer.size),
+    inputUnits(layer.inputUnits),
     higherDimension(layer.higherDimension),
     runningMean(layer.runningMean),
     runningVariance(layer.runningVariance)
@@ -80,7 +80,7 @@ BatchNormType<MatType>::BatchNormType(const BatchNormType& layer) :

 // Move constructor.
 template<typename MatType>
-BatchNormType<MatType>::BatchNormType(BatchNormType&& layer) :
+BatchNorm<MatType>::BatchNorm(BatchNorm&& layer) :
     Layer<MatType>(std::move(layer)),
     minAxis(std::move(layer.minAxis)),
     maxAxis(std::move(layer.maxAxis)),
@@ -90,7 +90,7 @@ BatchNormType<MatType>::BatchNormType(BatchNormType&& layer) :
     variance(std::move(layer.variance)),
     count(std::move(layer.count)),
     inputDimension(std::move(layer.inputDimension)),
-    size(std::move(layer.size)),
+    inputUnits(std::move(layer.inputUnits)),
     higherDimension(std::move(layer.higherDimension)),
     runningMean(std::move(layer.runningMean)),
     runningVariance(std::move(layer.runningVariance))
@@ -99,8 +99,8 @@ BatchNormType<MatType>::BatchNormType(BatchNormType&& layer) :
 }

 template<typename MatType>
-BatchNormType<MatType>&
-BatchNormType<MatType>::operator=(const BatchNormType& layer)
+BatchNorm<MatType>&
+BatchNorm<MatType>::operator=(const BatchNorm& layer)
 {
   if (&layer != this)
   {
@@ -113,7 +113,7 @@ BatchNormType<MatType>::operator=(const BatchNormType& layer)
     variance = layer.variance;
     count = layer.count;
     inputDimension = layer.inputDimension;
-    size = layer.size;
+    inputUnits = layer.inputUnits;
     higherDimension = layer.higherDimension;
     runningMean = layer.runningMean;
     runningVariance = layer.runningVariance;
@@ -123,9 +123,9 @@ BatchNormType<MatType>::operator=(const BatchNormType& layer)
 }

 template<typename MatType>
-BatchNormType<MatType>&
-BatchNormType<MatType>::operator=(
-    BatchNormType&& layer)
+BatchNorm<MatType>&
+BatchNorm<MatType>::operator=(
+    BatchNorm&& layer)
 {
   if (&layer != this)
   {
@@ -138,7 +138,7 @@ BatchNormType<MatType>::operator=(
     variance = std::move(layer.variance);
     count = std::move(layer.count);
     inputDimension = std::move(layer.inputDimension);
-    size = std::move(layer.size);
+    inputUnits = std::move(layer.inputUnits);
     higherDimension = std::move(layer.higherDimension);
     runningMean = std::move(layer.runningMean);
     runningVariance = std::move(layer.runningVariance);
@@ -148,40 +148,40 @@ BatchNormType<MatType>::operator=(
 }

 template<typename MatType>
-void BatchNormType<MatType>::SetWeights(const MatType& weightsIn)
+void BatchNorm<MatType>::SetWeights(const MatType& weightsIn)
 {
   MakeAlias(weights, weightsIn, WeightSize(), 1);
   // Gamma acts as the scaling parameters for the normalized output.
-  MakeAlias(gamma, weightsIn, size, 1);
+  MakeAlias(gamma, weightsIn, inputUnits, 1);
   // Beta acts as the shifting parameters for the normalized output.
-  MakeAlias(beta, weightsIn, size, 1, gamma.n_elem);
+  MakeAlias(beta, weightsIn, inputUnits, 1, gamma.n_elem);
 }

 template<typename MatType>
-void BatchNormType<MatType>::CustomInitialize(
+void BatchNorm<MatType>::CustomInitialize(
     MatType& W,
     const size_t elements)
 {
-  if (elements != 2 * size) {
-    throw std::invalid_argument("BatchNormType::CustomInitialize(): wrong "
+  if (elements != 2 * inputUnits) {
+    throw std::invalid_argument("BatchNorm::CustomInitialize(): wrong "
        "elements size!");
   }
   MatType gammaTemp;
   MatType betaTemp;
   // Gamma acts as the scaling parameters for the normalized output.
-  MakeAlias(gammaTemp, W, size, 1);
+  MakeAlias(gammaTemp, W, inputUnits, 1);
   // Beta acts as the shifting parameters for the normalized output.
-  MakeAlias(betaTemp, W, size, 1, gammaTemp.n_elem);
+  MakeAlias(betaTemp, W, inputUnits, 1, gammaTemp.n_elem);

-  gammaTemp.fill(1.0);
-  betaTemp.fill(0.0);
+  gammaTemp.ones();
+  betaTemp.zeros();

-  runningMean.zeros(size, 1);
-  runningVariance.ones(size, 1);
+  runningMean.zeros(inputUnits, 1);
+  runningVariance.ones(inputUnits, 1);
 }

 template<typename MatType>
-void BatchNormType<MatType>::Forward(
+void BatchNorm<MatType>::Forward(
     const MatType& input,
     MatType& output)
 {
@@ -203,31 +203,32 @@ void BatchNormType<MatType>::Forward(
     // Input corresponds to output from previous layer.
     // Used a cube for simplicity.
     CubeType inputTemp;
-    MakeAlias(inputTemp, input, inputSize, size,
+    MakeAlias(inputTemp, input, inputSize, inputUnits,
         batchSize * higherDimension, 0, false);

     // Initialize output to same size and values for convenience.
     CubeType outputTemp;
-    MakeAlias(outputTemp, output, inputSize, size,
+    MakeAlias(outputTemp, output, inputSize, inputUnits,
         batchSize * higherDimension, 0, false);
     outputTemp = inputTemp;

     // Calculate mean and variance over all channels.
     MatType mean = sum(sum(inputTemp, 2), 0) / m;
-    variance = sum(sum(pow(
-        inputTemp.each_slice() - repmat(mean, inputSize, 1), 2), 2), 0) / m;
+    variance = sum(sum(square(
+        inputTemp.each_slice() - repmat(mean, inputSize, 1)), 2), 0) / m;

     outputTemp.each_slice() -= repmat(mean, inputSize, 1);

     // Used in backward propagation.
-    inputMean.set_size(arma::size(inputTemp));
+    inputMean.set_size(size(inputTemp));
     inputMean = outputTemp;

     // Normalize output.
-    outputTemp.each_slice() /= sqrt(repmat(variance, inputSize, 1) + eps);
+    outputTemp.each_slice() /= sqrt(repmat(variance, inputSize, 1) +
+        ElemType(eps));

     // Re-used in backward propagation.
-    normalized.set_size(arma::size(inputTemp));
+    normalized.set_size(size(inputTemp));
     normalized = outputTemp;

     outputTemp.each_slice() %= repmat(gamma.t(), inputSize, 1);
@@ -235,11 +236,11 @@ void BatchNormType<MatType>::Forward(

     count += 1;
     // Value for average factor which used to update running parameters.
-    double averageFactor = average ? 1.0 / count : momentum;
+    ElemType averageFactor = ElemType(average ? 1.0 / count : momentum);

-    double nElements = 0.0;
+    ElemType nElements = 0;
     if (m - 1 != 0)
-      nElements = m * (1.0 / (m - 1));
+      nElements = m * (ElemType(1) / (m - 1));

     // Update running mean and running variance.
     runningMean = (1 - averageFactor) * runningMean + averageFactor *
@@ -252,35 +253,35 @@ void BatchNormType<MatType>::Forward(
     // Normalize the input and scale and shift the output.
     output = input;
     CubeType outputTemp;
-    MakeAlias(outputTemp, output, inputSize, size,
+    MakeAlias(outputTemp, output, inputSize, inputUnits,
        batchSize * higherDimension, 0, false);

     outputTemp.each_slice() -= repmat(runningMean.t(), inputSize, 1);
     outputTemp.each_slice() /= sqrt(repmat(runningVariance.t(),
-        inputSize, 1) + eps);
+        inputSize, 1) + ElemType(eps));
     outputTemp.each_slice() %= repmat(gamma.t(), inputSize, 1);
     outputTemp.each_slice() += repmat(beta.t(), inputSize, 1);
   }
 }

 template<typename MatType>
-void BatchNormType<MatType>::Backward(
+void BatchNorm<MatType>::Backward(
     const MatType& /* input */,
     const MatType& /* output */,
     const MatType& gy,
     MatType& g)
 {
-  const MatType stdInv = 1.0 / sqrt(variance + eps);
+  const MatType stdInv = 1 / sqrt(variance + ElemType(eps));

   const size_t batchSize = gy.n_cols;
   const size_t inputSize = inputDimension;
   const size_t m = inputSize * batchSize * higherDimension;

   CubeType gyTemp;
-  MakeAlias(gyTemp, gy, inputSize, size,
+  MakeAlias(gyTemp, gy, inputSize, inputUnits,
      batchSize * higherDimension, 0, false);
   CubeType gTemp;
-  MakeAlias(gTemp, g, inputSize, size,
+  MakeAlias(gTemp, g, inputSize, inputUnits,
      batchSize * higherDimension, 0, false);

   // Step 1: dl / dxhat.
@@ -288,24 +289,24 @@ void BatchNormType<MatType>::Backward(

   // Step 2: sum dl / dxhat * (x - mu) * -0.5 * stdInv^3.
   MatType temp = sum(sum(norm % inputMean, 2), 0);
-  MatType vars = temp % pow(stdInv, 3) * (-0.5);
+  MatType vars = -temp % pow(stdInv, 3) / 2;

   // Step 3: dl / dxhat * 1 / stdInv + variance * 2 * (x - mu) / m +
   //     dl / dmu * 1 / m.
   gTemp = (norm.each_slice() % repmat(stdInv, inputSize, 1)) +
-      ((inputMean.each_slice() % repmat(vars, inputSize, 1) * 2.0) / m);
+      ((inputMean.each_slice() % repmat(vars, inputSize, 1) * 2) / m);

   // Step 4: sum (dl / dxhat * -1 / stdInv) + variance *
   //     sum (-2 * (x - mu)) / m.
   MatType normTemp = sum(sum((norm.each_slice() %
      repmat(-stdInv, inputSize, 1)) +
-      (inputMean.each_slice() % repmat(vars, inputSize, 1) * (-2.0) / m),
+      -2 * (inputMean.each_slice() % repmat(vars, inputSize, 1) / m),
      2), 0) / m;
   gTemp.each_slice() += repmat(normTemp, inputSize, 1);
 }

 template<typename MatType>
-void BatchNormType<MatType>::Gradient(
+void BatchNorm<MatType>::Gradient(
     const MatType& /* input */,
     const MatType& error,
     MatType& gradient)
@@ -313,7 +314,7 @@ void BatchNormType<MatType>::Gradient(
   const size_t inputSize = inputDimension;

   CubeType errorTemp;
-  MakeAlias(errorTemp, error, inputSize, size,
+  MakeAlias(errorTemp, error, inputSize, inputUnits,
      error.n_cols * higherDimension, 0, false);

   // Step 5: dl / dy * xhat.
@@ -326,7 +327,7 @@ void BatchNormType<MatType>::Gradient(
 }

 template<typename MatType>
-void BatchNormType<MatType>::ComputeOutputDimensions()
+void BatchNorm<MatType>::ComputeOutputDimensions()
 {
   if (minAxis > maxAxis)
   {
@@ -354,9 +355,9 @@ void BatchNormType<MatType>::ComputeOutputDimensions()
   for (size_t i = 0; i < mainMinAxis; i++)
     inputDimension *= this->inputDimensions[i];

-  size = this->inputDimensions[mainMinAxis];
+  inputUnits = this->inputDimensions[mainMinAxis];
   for (size_t i = mainMinAxis + 1; i <= mainMaxAxis; i++)
-    size *= this->inputDimensions[i];
+    inputUnits *= this->inputDimensions[i];

   higherDimension = 1;
   for (size_t i = mainMaxAxis + 1; i < this->inputDimensions.size(); i++)
@@ -365,7 +366,7 @@ void BatchNormType<MatType>::ComputeOutputDimensions()

 template<typename MatType>
 template<typename Archive>
-void BatchNormType<MatType>::serialize(
+void BatchNorm<MatType>::serialize(
     Archive& ar, const uint32_t /* version */)
 {
   ar(cereal::base_class<Layer<MatType>>(this));
@@ -380,7 +381,7 @@ void BatchNormType<MatType>::serialize(
   ar(CEREAL_NVP(runningVariance));
   ar(CEREAL_NVP(inputMean));
   ar(CEREAL_NVP(inputDimension));
-  ar(CEREAL_NVP(size));
+  ar(CEREAL_NVP(inputUnits));
   ar(CEREAL_NVP(higherDimension));
 }

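Across batch_norm.hpp and batch_norm_impl.hpp, the class template is renamed from BatchNormType to BatchNorm (the old BatchNorm = BatchNormType<arma::mat> alias is removed), the size member becomes inputUnits, and eps and the running-statistics updates are routed through ElemType so that low-precision matrix types are not silently promoted to double. The following is a minimal standalone sketch under those assumptions; it is not part of the diff, and the surrounding Layer accessors (InputDimensions(), ComputeOutputDimensions(), CustomInitialize(), SetWeights(), Training()) are assumed to behave as in earlier 4.x releases.

    // Hypothetical usage sketch (not from the diff): the renamed BatchNorm
    // layer used directly on single-precision data.
    #include <mlpack.hpp>

    int main()
    {
      // Batch of 3 points, each a flattened 4x4 image with 2 channels.
      arma::fmat input(4 * 4 * 2, 3, arma::fill::randn);
      arma::fmat output(4 * 4 * 2, 3);

      mlpack::BatchNorm<arma::fmat> bn;     // default minAxis = maxAxis = 2 (channels)
      bn.InputDimensions() = { 4, 4, 2 };
      bn.ComputeOutputDimensions();         // sets inputUnits = 2 internally

      // Gamma and beta live in externally held weight memory, as for any mlpack layer.
      arma::fmat weights(bn.WeightSize(), 1);
      bn.CustomInitialize(weights, bn.WeightSize());
      bn.SetWeights(weights);

      bn.Training() = true;                 // use batch statistics in Forward()
      bn.Forward(input, output);            // eps is applied as ElemType (float here)

      return 0;
    }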