mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -18,7 +18,7 @@ namespace mlpack {
18
18
 
19
19
  // Create the LinearRecurrent layer.
20
20
  template<typename MatType, typename RegularizerType>
21
- LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType() :
21
+ LinearRecurrent<MatType, RegularizerType>::LinearRecurrent() :
22
22
  RecurrentLayer<MatType>(),
23
23
  inSize(0),
24
24
  outSize(0)
@@ -27,7 +27,7 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType() :
27
27
  }
28
28
 
29
29
  template<typename MatType, typename RegularizerType>
30
- LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
30
+ LinearRecurrent<MatType, RegularizerType>::LinearRecurrent(
31
31
  const size_t outSize,
32
32
  RegularizerType regularizer) :
33
33
  RecurrentLayer<MatType>(),
@@ -40,8 +40,8 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
40
40
 
41
41
  // Copy constructor.
42
42
  template<typename MatType, typename RegularizerType>
43
- LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
44
- const LinearRecurrentType& layer) :
43
+ LinearRecurrent<MatType, RegularizerType>::LinearRecurrent(
44
+ const LinearRecurrent& layer) :
45
45
  RecurrentLayer<MatType>(layer),
46
46
  inSize(layer.inSize),
47
47
  outSize(layer.outSize),
@@ -52,8 +52,8 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
52
52
 
53
53
  // Move constructor.
54
54
  template<typename MatType, typename RegularizerType>
55
- LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
56
- LinearRecurrentType&& layer) :
55
+ LinearRecurrent<MatType, RegularizerType>::LinearRecurrent(
56
+ LinearRecurrent&& layer) :
57
57
  RecurrentLayer<MatType>(std::move(layer)),
58
58
  inSize(std::move(layer.inSize)),
59
59
  outSize(std::move(layer.outSize)),
@@ -66,9 +66,9 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
66
66
 
67
67
  // Copy operator.
68
68
  template<typename MatType, typename RegularizerType>
69
- LinearRecurrentType<MatType, RegularizerType>&
70
- LinearRecurrentType<MatType, RegularizerType>::operator=(
71
- const LinearRecurrentType& layer)
69
+ LinearRecurrent<MatType, RegularizerType>&
70
+ LinearRecurrent<MatType, RegularizerType>::operator=(
71
+ const LinearRecurrent& layer)
72
72
  {
73
73
  if (&layer != this)
74
74
  {
@@ -83,9 +83,9 @@ LinearRecurrentType<MatType, RegularizerType>::operator=(
83
83
 
84
84
  // Move operator.
85
85
  template<typename MatType, typename RegularizerType>
86
- LinearRecurrentType<MatType, RegularizerType>&
87
- LinearRecurrentType<MatType, RegularizerType>::operator=(
88
- LinearRecurrentType&& layer)
86
+ LinearRecurrent<MatType, RegularizerType>&
87
+ LinearRecurrent<MatType, RegularizerType>::operator=(
88
+ LinearRecurrent&& layer)
89
89
  {
90
90
  if (&layer != this)
91
91
  {
@@ -104,7 +104,7 @@ LinearRecurrentType<MatType, RegularizerType>::operator=(
104
104
 
105
105
  // Set the parameters of the layer.
106
106
  template<typename MatType, typename RegularizerType>
107
- void LinearRecurrentType<MatType, RegularizerType>::SetWeights(
107
+ void LinearRecurrent<MatType, RegularizerType>::SetWeights(
108
108
  const MatType& weightsIn)
109
109
  {
110
110
  MakeAlias(parameters, weightsIn, WeightSize(), 1);
@@ -116,7 +116,7 @@ void LinearRecurrentType<MatType, RegularizerType>::SetWeights(
116
116
 
117
117
  // Forward pass of linear recurrent layer.
118
118
  template<typename MatType, typename RegularizerType>
119
- void LinearRecurrentType<MatType, RegularizerType>::Forward(
119
+ void LinearRecurrent<MatType, RegularizerType>::Forward(
120
120
  const MatType& input, MatType& output)
121
121
  {
122
122
  // Take the forward step: f(x) = Wx + Uh + b.
@@ -127,7 +127,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Forward(
127
127
  else
128
128
  {
129
129
  output = weights * input +
130
- recurrentWeights * this->RecurrentState(this->PreviousStep());
130
+ recurrentWeights * previousOutput;
131
131
  }
132
132
 
133
133
  #pragma omp for
@@ -136,12 +136,12 @@ void LinearRecurrentType<MatType, RegularizerType>::Forward(
136
136
 
137
137
  // Update the recurrent state if needed.
138
138
  if (!this->AtFinalStep())
139
- this->RecurrentState(this->CurrentStep()) = output;
139
+ currentOutput = output;
140
140
  }
141
141
 
142
142
  // Backward pass of linear recurrent layer.
143
143
  template<typename MatType, typename RegularizerType>
144
- void LinearRecurrentType<MatType, RegularizerType>::Backward(
144
+ void LinearRecurrent<MatType, RegularizerType>::Backward(
145
145
  const MatType& /* input */,
146
146
  const MatType& /* output */,
147
147
  const MatType& gy,
@@ -159,7 +159,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Backward(
159
159
  {
160
160
  // Via the recurrence, the result is equivalent, just with the recurrent
161
161
  // gradient as the gy parameter.
162
- g += weights.t() * this->RecurrentGradient(this->CurrentStep());
162
+ g += weights.t() * currentGradient;
163
163
  }
164
164
 
165
165
  if (this->HasPreviousStep())
@@ -169,20 +169,19 @@ void LinearRecurrentType<MatType, RegularizerType>::Backward(
169
169
  //
170
170
  // With respect to the output, we can just propagate back through the
171
171
  // recurrent weights.
172
- this->RecurrentGradient(this->PreviousStep()) = recurrentWeights.t() * gy;
172
+ previousGradient = recurrentWeights.t() * gy;
173
173
 
174
174
  if (!this->AtFinalStep())
175
175
  {
176
176
  // If we also have a path from dz/dh^t, this can be added.
177
- this->RecurrentGradient(this->PreviousStep()) +=
178
- recurrentWeights.t() * this->RecurrentGradient(this->CurrentStep());
177
+ previousGradient += recurrentWeights.t() * currentGradient;
179
178
  }
180
179
  }
181
180
  }
182
181
 
183
182
  // Compute the gradient with respect to the input.
184
183
  template<typename MatType, typename RegularizerType>
185
- void LinearRecurrentType<MatType, RegularizerType>::Gradient(
184
+ void LinearRecurrent<MatType, RegularizerType>::Gradient(
186
185
  const MatType& input,
187
186
  const MatType& error,
188
187
  MatType& gradient)
@@ -204,7 +203,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Gradient(
204
203
  if (this->HasPreviousStep())
205
204
  {
206
205
  gradient.submat(whOffset, 0, bOffset - 1, 0) =
207
- vectorise(error * this->RecurrentState(this->PreviousStep()).t());
206
+ vectorise(error * previousOutput.t());
208
207
  }
209
208
  gradient.submat(bOffset, 0, gradient.n_rows - 1, 0) = sum(error, 1);
210
209
 
@@ -215,15 +214,14 @@ void LinearRecurrentType<MatType, RegularizerType>::Gradient(
215
214
  if (!this->AtFinalStep())
216
215
  {
217
216
  gradient.submat(0, 0, whOffset - 1, 0) +=
218
- vectorise(this->RecurrentGradient(this->CurrentStep()) * input.t());
217
+ vectorise(currentGradient * input.t());
219
218
  if (this->HasPreviousStep())
220
219
  {
221
220
  gradient.submat(whOffset, 0, bOffset - 1, 0) +=
222
- vectorise(this->RecurrentGradient(this->CurrentStep()) *
223
- this->RecurrentState(this->PreviousStep()).t());
221
+ vectorise(currentGradient * previousOutput.t());
224
222
  }
225
223
  gradient.submat(bOffset, 0, gradient.n_rows - 1, 0) += sum(
226
- this->RecurrentGradient(this->CurrentStep()), 1);
224
+ currentGradient, 1);
227
225
 
228
226
  // this->HiddenDeriv(this->PreviousStep()) was already computed in
229
227
  // Backward(), so no need to do it here.
@@ -232,7 +230,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Gradient(
232
230
 
233
231
  // Get the total number of trainable parameters.
234
232
  template<typename MatType, typename RegularizerType>
235
- size_t LinearRecurrentType<MatType, RegularizerType>::WeightSize() const
233
+ size_t LinearRecurrent<MatType, RegularizerType>::WeightSize() const
236
234
  {
237
235
  return (inSize * outSize) /* weight matrix */ +
238
236
  (outSize * outSize) /* recurrent state matrix */ +
@@ -240,7 +238,7 @@ size_t LinearRecurrentType<MatType, RegularizerType>::WeightSize() const
240
238
  }
241
239
 
242
240
  template<typename MatType, typename RegularizerType>
243
- size_t LinearRecurrentType<MatType, RegularizerType>::RecurrentSize() const
241
+ size_t LinearRecurrent<MatType, RegularizerType>::RecurrentSize() const
244
242
  {
245
243
  return outSize;
246
244
  }
@@ -248,7 +246,7 @@ size_t LinearRecurrentType<MatType, RegularizerType>::RecurrentSize() const
248
246
  // Compute the output dimensions of the layer, assuming that inputDimension has
249
247
  // been set.
250
248
  template<typename MatType, typename RegularizerType>
251
- void LinearRecurrentType<MatType, RegularizerType>::ComputeOutputDimensions()
249
+ void LinearRecurrent<MatType, RegularizerType>::ComputeOutputDimensions()
252
250
  {
253
251
  // Compute the total number of input dimensions.
254
252
  inSize = this->inputDimensions[0];
@@ -261,10 +259,41 @@ void LinearRecurrentType<MatType, RegularizerType>::ComputeOutputDimensions()
261
259
  this->outputDimensions[0] = outSize;
262
260
  }
263
261
 
262
+ template<typename MatType, typename RegularizerType>
263
+ void LinearRecurrent<MatType, RegularizerType>::OnStepChanged(
264
+ const size_t step,
265
+ const size_t /* batchSize */,
266
+ const size_t activeBatchSize,
267
+ const bool backwards)
268
+ {
269
+ // Make aliases for the output from the recurrent state.
270
+ MakeAlias(currentOutput, this->RecurrentState(step),
271
+ outSize, activeBatchSize);
272
+
273
+ if (this->HasPreviousStep())
274
+ {
275
+ MakeAlias(previousOutput, this->RecurrentState(this->PreviousStep()),
276
+ outSize, activeBatchSize);
277
+ }
278
+
279
+ // Make aliases for the gradient from the recurrent gradient.
280
+ if (backwards)
281
+ {
282
+ MakeAlias(currentGradient, this->RecurrentGradient(step),
283
+ outSize, activeBatchSize);
284
+
285
+ if (this->HasPreviousStep())
286
+ {
287
+ MakeAlias(previousGradient, this->RecurrentGradient(this->PreviousStep()),
288
+ outSize, activeBatchSize);
289
+ }
290
+ }
291
+ }
292
+
264
293
  // Serialize the layer.
265
294
  template<typename MatType, typename RegularizerType>
266
295
  template<typename Archive>
267
- void LinearRecurrentType<MatType, RegularizerType>::serialize(
296
+ void LinearRecurrent<MatType, RegularizerType>::serialize(
268
297
  Archive& ar, const uint32_t /* version */)
269
298
  {
270
299
  ar(cereal::base_class<RecurrentLayer<MatType>>(this));
@@ -29,37 +29,31 @@ namespace mlpack {
29
29
  * computation.
30
30
  */
31
31
  template <typename MatType = arma::mat>
32
- class LogSoftMaxType : public Layer<MatType>
32
+ class LogSoftMax : public Layer<MatType>
33
33
  {
34
34
  public:
35
+ // Convenience typedef to access the element type of the weights and data.
36
+ using ElemType = typename MatType::elem_type;
37
+
35
38
  /**
36
39
  * Create the LogSoftmax layer.
37
40
  */
38
- LogSoftMaxType();
41
+ LogSoftMax();
39
42
 
40
- //! Clone the LogSoftMaxType object. This handles polymorphism correctly.
41
- LogSoftMaxType* Clone() const { return new LogSoftMaxType(*this); }
43
+ //! Clone the LogSoftMax object. This handles polymorphism correctly.
44
+ LogSoftMax* Clone() const { return new LogSoftMax(*this); }
42
45
 
43
46
  // Virtual destructor.
44
- virtual ~LogSoftMaxType() { }
47
+ virtual ~LogSoftMax() { }
45
48
 
46
- //! Copy the given LogSoftMaxType.
47
- LogSoftMaxType(const LogSoftMaxType& other);
48
- //! Take ownership of the given LogSoftMaxType.
49
- LogSoftMaxType(LogSoftMaxType&& other);
50
- //! Copy the given LogSoftMaxType.
51
- LogSoftMaxType& operator=(const LogSoftMaxType& other);
52
- //! Take ownership of the given LogSoftMaxType.
53
- LogSoftMaxType& operator=(LogSoftMaxType&& other);
54
-
55
- /**
56
- * A wrapper function to call the correct implementation according to the
57
- * specific matrix type (e.g., arma, coot).
58
- *
59
- * @param input Input data used for evaluating the specified function.
60
- * @param output Resulting output activation.
61
- */
62
- void Forward(const MatType& input, MatType& output);
49
+ //! Copy the given LogSoftMax.
50
+ LogSoftMax(const LogSoftMax& other);
51
+ //! Take ownership of the given LogSoftMax.
52
+ LogSoftMax(LogSoftMax&& other);
53
+ //! Copy the given LogSoftMax.
54
+ LogSoftMax& operator=(const LogSoftMax& other);
55
+ //! Take ownership of the given LogSoftMax.
56
+ LogSoftMax& operator=(LogSoftMax&& other);
63
57
 
64
58
  /**
65
59
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -68,15 +62,7 @@ class LogSoftMaxType : public Layer<MatType>
68
62
  * @param input Input data used for evaluating the specified function.
69
63
  * @param output Resulting output activation.
70
64
  */
71
- void ForwardImpl(const MatType& input, MatType& output,
72
- const typename std::enable_if_t<
73
- arma::is_arma_type<MatType>::value>* = 0);
74
-
75
- #ifdef MLPACK_HAS_COOT
76
- void ForwardImpl(const MatType& input, MatType& output,
77
- const typename std::enable_if_t<
78
- coot::is_coot_type<MatType>::value>* = 0);
79
- #endif
65
+ void Forward(const MatType& input, MatType& output);
80
66
 
81
67
  /**
82
68
  * Ordinary feed backward pass of a neural network, calculating the function
@@ -101,11 +87,6 @@ class LogSoftMaxType : public Layer<MatType>
101
87
  }
102
88
  }; // class LogSoftmaxType
103
89
 
104
- // Convenience typedefs.
105
-
106
- // Standard Linear layer using no regularization.
107
- using LogSoftMax = LogSoftMaxType<arma::mat>;
108
-
109
90
  } // namespace mlpack
110
91
 
111
92
  // Include implementation.
@@ -18,29 +18,29 @@
18
18
  namespace mlpack {
19
19
 
20
20
  template<typename MatType>
21
- LogSoftMaxType<MatType>::LogSoftMaxType() :
21
+ LogSoftMax<MatType>::LogSoftMax() :
22
22
  Layer<MatType>()
23
23
  {
24
24
  // Nothing to do here.
25
25
  }
26
26
 
27
27
  template<typename MatType>
28
- LogSoftMaxType<MatType>::LogSoftMaxType(const LogSoftMaxType& other) :
28
+ LogSoftMax<MatType>::LogSoftMax(const LogSoftMax& other) :
29
29
  Layer<MatType>(other)
30
30
  {
31
31
  // Nothing to do here.
32
32
  }
33
33
 
34
34
  template<typename MatType>
35
- LogSoftMaxType<MatType>::LogSoftMaxType(LogSoftMaxType&& other) :
35
+ LogSoftMax<MatType>::LogSoftMax(LogSoftMax&& other) :
36
36
  Layer<MatType>(std::move(other))
37
37
  {
38
38
  // Nothing to do here.
39
39
  }
40
40
 
41
41
  template<typename MatType>
42
- LogSoftMaxType<MatType>&
43
- LogSoftMaxType<MatType>::operator=(const LogSoftMaxType& other)
42
+ LogSoftMax<MatType>&
43
+ LogSoftMax<MatType>::operator=(const LogSoftMax& other)
44
44
  {
45
45
  if (&other != this)
46
46
  {
@@ -51,8 +51,8 @@ LogSoftMaxType<MatType>::operator=(const LogSoftMaxType& other)
51
51
  }
52
52
 
53
53
  template<typename MatType>
54
- LogSoftMaxType<MatType>&
55
- LogSoftMaxType<MatType>::operator=(LogSoftMaxType&& other)
54
+ LogSoftMax<MatType>&
55
+ LogSoftMax<MatType>::operator=(LogSoftMax&& other)
56
56
  {
57
57
  if (&other != this)
58
58
  {
@@ -63,85 +63,69 @@ LogSoftMaxType<MatType>::operator=(LogSoftMaxType&& other)
63
63
  }
64
64
 
65
65
  template<typename MatType>
66
- void LogSoftMaxType<MatType>::Forward(const MatType& input, MatType& output)
66
+ void LogSoftMax<MatType>::Forward(const MatType& input, MatType& output)
67
67
  {
68
- ForwardImpl(input, output);
69
- }
70
-
71
- template<typename MatType>
72
- void LogSoftMaxType<MatType>::ForwardImpl(
73
- const MatType& input,
74
- MatType& output,
75
- const typename std::enable_if_t<arma::is_arma_type<MatType>::value>*)
76
- {
77
- MatType maxInput = repmat(max(input, 0), input.n_rows, 1);
78
- output = (maxInput - input);
79
-
80
- // Approximation of the base-e exponential function. The accuracy, however, is
81
- // about 0.00001 lower than using exp. Credits go to Leon Bottou.
82
- #pragma omp parallel for
83
- for (size_t i = 0; i < output.n_elem; ++i)
68
+ if constexpr (IsArma<MatType>::value)
84
69
  {
85
- double x = output(i);
86
- //! Fast approximation of exp(-x) for x positive.
87
- static constexpr double A0 = 1.0;
88
- static constexpr double A1 = 0.125;
89
- static constexpr double A2 = 0.0078125;
90
- static constexpr double A3 = 0.00032552083;
91
- static constexpr double A4 = 1.0172526e-5;
92
-
93
- if (x < 13.0)
70
+ MatType maxInput = repmat(max(input, 0), input.n_rows, 1);
71
+ output = (maxInput - input);
72
+
73
+ // Approximation of the base-e exponential function. The accuracy, however,
74
+ // is about 0.00001 lower than using exp. Credits go to Leon Bottou.
75
+ #pragma omp parallel for
76
+ for (size_t i = 0; i < output.n_elem; ++i)
94
77
  {
95
- double y = A0 + x * (A1 + x * (A2 + x * (A3 + x * A4)));
96
- y *= y;
97
- y *= y;
98
- y *= y;
99
- y = 1 / y;
100
- output(i) = y;
78
+ double x = output(i);
79
+ //! Fast approximation of exp(-x) for x positive.
80
+ static constexpr double A0 = 1.0;
81
+ static constexpr double A1 = 0.125;
82
+ static constexpr double A2 = 0.0078125;
83
+ static constexpr double A3 = 0.00032552083;
84
+ static constexpr double A4 = 1.0172526e-5;
85
+
86
+ if (x < 13.0)
87
+ {
88
+ double y = A0 + x * (A1 + x * (A2 + x * (A3 + x * A4)));
89
+ y *= y;
90
+ y *= y;
91
+ y *= y;
92
+ y = 1 / y;
93
+ output(i) = ElemType(y);
94
+ }
95
+ else
96
+ {
97
+ output(i) = 0;
98
+ }
101
99
  }
102
- else
100
+
101
+ #pragma omp parallel for
102
+ for (size_t col = 0; col < maxInput.n_cols; ++col)
103
103
  {
104
- output(i) = 0.0;
104
+ ElemType colSum = 0;
105
+ for (size_t row = 0; row < output.n_rows; ++row)
106
+ {
107
+ colSum += output(row, col);
108
+ }
109
+ ElemType logSum = std::log(colSum);
110
+ for (size_t row = 0; row < maxInput.n_rows; ++row)
111
+ {
112
+ maxInput(row, col) += logSum;
113
+ }
105
114
  }
115
+ output = input - maxInput;
106
116
  }
107
-
108
- #pragma omp parallel for
109
- for (size_t col = 0; col < maxInput.n_cols; ++col)
117
+ else if constexpr (IsCoot<MatType>::value)
110
118
  {
111
- double colSum = 0.0;
112
- for (size_t row = 0; row < output.n_rows; ++row)
113
- {
114
- colSum += output(row, col);
115
- }
116
- double logSum = std::log(colSum);
117
- for (size_t row = 0; row < maxInput.n_rows; ++row)
118
- {
119
- maxInput(row, col) += logSum;
120
- }
119
+ MatType maxInput = repmat(max(input), input.n_rows, 1);
120
+ output = (maxInput - input);
121
+ output = exp(-output);
122
+ maxInput.each_row() += log(sum(output));
123
+ output = input - maxInput;
121
124
  }
122
-
123
- output = input - maxInput;
124
- }
125
-
126
- #ifdef MLPACK_HAS_COOT
127
-
128
- template<typename MatType>
129
- void LogSoftMaxType<MatType>::ForwardImpl(
130
- const MatType& input,
131
- MatType& output,
132
- const typename std::enable_if_t<coot::is_coot_type<MatType>::value>*)
133
- {
134
- MatType maxInput = repmat(max(input), input.n_rows, 1);
135
- output = (maxInput - input);
136
- output = exp(output * -1);
137
- maxInput.each_row() += log(sum(output));
138
- output = input - maxInput;
139
125
  }
140
126
 
141
- #endif
142
-
143
127
  template<typename MatType>
144
- void LogSoftMaxType<MatType>::Backward(
128
+ void LogSoftMax<MatType>::Backward(
145
129
  const MatType& /* input */,
146
130
  const MatType& output,
147
131
  const MatType& gy,
@@ -53,11 +53,14 @@ namespace mlpack {
53
53
  * computation.
54
54
  */
55
55
  template<typename MatType = arma::mat>
56
- class LSTMType : public RecurrentLayer<MatType>
56
+ class LSTM : public RecurrentLayer<MatType>
57
57
  {
58
58
  public:
59
- //! Create the LSTM object.
60
- LSTMType();
59
+ // Convenience typedef to access the element type of the weights and data.
60
+ using ElemType = typename MatType::elem_type;
61
+
62
+ // Create the LSTM object.
63
+ LSTM();
61
64
 
62
65
  /**
63
66
  * Create the LSTM layer object using the specified parameters.
@@ -65,21 +68,21 @@ class LSTMType : public RecurrentLayer<MatType>
65
68
  * @param outSize The number of output units.
66
69
  * @param rho Maximum number of steps to backpropagate through time (BPTT).
67
70
  */
68
- LSTMType(const size_t outSize);
71
+ LSTM(const size_t outSize);
69
72
 
70
- //! Clone the LSTMType object. This handles polymorphism correctly.
71
- LSTMType* Clone() const { return new LSTMType(*this); }
73
+ // Clone the LSTM object. This handles polymorphism correctly.
74
+ LSTM* Clone() const { return new LSTM(*this); }
72
75
 
73
- //! Copy the given LSTMType object.
74
- LSTMType(const LSTMType& other);
75
- //! Take ownership of the given LSTMType object's data.
76
- LSTMType(LSTMType&& other);
77
- //! Copy the given LSTMType object.
78
- LSTMType& operator=(const LSTMType& other);
79
- //! Take ownership of the given LSTMType object's data.
80
- LSTMType& operator=(LSTMType&& other);
76
+ // Copy the given LSTM object.
77
+ LSTM(const LSTM& other);
78
+ // Take ownership of the given LSTM object's data.
79
+ LSTM(LSTM&& other);
80
+ // Copy the given LSTM object.
81
+ LSTM& operator=(const LSTM& other);
82
+ // Take ownership of the given LSTM object's data.
83
+ LSTM& operator=(LSTM&& other);
81
84
 
82
- virtual ~LSTMType() { }
85
+ virtual ~LSTM() { }
83
86
 
84
87
  /**
85
88
  * Reset the layer parameter. The method is called to
@@ -217,6 +220,12 @@ class LSTMType : public RecurrentLayer<MatType>
217
220
  this->outputDimensions[0] = outSize;
218
221
  }
219
222
 
223
+ // Update the internal aliases of the layer when the step changes.
224
+ void OnStepChanged(const size_t step,
225
+ const size_t batchSize,
226
+ const size_t activeBatchSize,
227
+ const bool backwards);
228
+
220
229
  /**
221
230
  * Serialize the layer.
222
231
  */
@@ -287,20 +296,8 @@ class LSTMType : public RecurrentLayer<MatType>
287
296
  MatType nextDeltaForgetGate;
288
297
  MatType nextDeltaOutputGate;
289
298
  MatType nextDeltaCell;
290
-
291
- // Calling this function will set all the aliases for the functions above to
292
- // the correct places in the current recurrent state methods.
293
- void SetInternalAliases(const size_t batchSize);
294
-
295
- // Calling this function will set up workspace memory for the backward pass,
296
- // if necessary.
297
- void SetBackwardWorkspace(const size_t batchSize);
298
- }; // class LSTMType
299
-
300
- // Convenience typedefs.
301
-
302
- // Standard LSTM layer.
303
- using LSTM = LSTMType<arma::mat>;
299
+ MatType nextForgetGate;
300
+ }; // class LSTM
304
301
 
305
302
  } // namespace mlpack
306
303