mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415)
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -18,7 +18,7 @@
18
18
  namespace mlpack {
19
19
 
20
20
  template<typename MatType>
21
- LSTMType<MatType>::LSTMType() :
21
+ LSTM<MatType>::LSTM() :
22
22
  RecurrentLayer<MatType>(),
23
23
  inSize(0),
24
24
  outSize(0)
@@ -27,7 +27,7 @@ LSTMType<MatType>::LSTMType() :
27
27
  }
28
28
 
29
29
  template<typename MatType>
30
- LSTMType<MatType>::LSTMType(const size_t outSize) :
30
+ LSTM<MatType>::LSTM(const size_t outSize) :
31
31
  RecurrentLayer<MatType>(),
32
32
  inSize(0),
33
33
  outSize(outSize)
@@ -36,7 +36,7 @@ LSTMType<MatType>::LSTMType(const size_t outSize) :
36
36
  }
37
37
 
38
38
  template<typename MatType>
39
- LSTMType<MatType>::LSTMType(const LSTMType& layer) :
39
+ LSTM<MatType>::LSTM(const LSTM& layer) :
40
40
  RecurrentLayer<MatType>(layer),
41
41
  inSize(layer.inSize),
42
42
  outSize(layer.outSize)
@@ -45,7 +45,7 @@ LSTMType<MatType>::LSTMType(const LSTMType& layer) :
45
45
  }
46
46
 
47
47
  template<typename MatType>
48
- LSTMType<MatType>::LSTMType(LSTMType&& layer) :
48
+ LSTM<MatType>::LSTM(LSTM&& layer) :
49
49
  RecurrentLayer<MatType>(std::move(layer)),
50
50
  inSize(layer.inSize),
51
51
  outSize(layer.outSize)
@@ -55,7 +55,7 @@ LSTMType<MatType>::LSTMType(LSTMType&& layer) :
55
55
  }
56
56
 
57
57
  template<typename MatType>
58
- LSTMType<MatType>& LSTMType<MatType>::operator=(const LSTMType& layer)
58
+ LSTM<MatType>& LSTM<MatType>::operator=(const LSTM& layer)
59
59
  {
60
60
  if (this != &layer)
61
61
  {
@@ -68,7 +68,7 @@ LSTMType<MatType>& LSTMType<MatType>::operator=(const LSTMType& layer)
68
68
  }
69
69
 
70
70
  template<typename MatType>
71
- LSTMType<MatType>& LSTMType<MatType>::operator=(LSTMType&& layer)
71
+ LSTM<MatType>& LSTM<MatType>::operator=(LSTM&& layer)
72
72
  {
73
73
  if (this != &layer)
74
74
  {
@@ -84,7 +84,7 @@ LSTMType<MatType>& LSTMType<MatType>::operator=(LSTMType&& layer)
84
84
  }
85
85
 
86
86
  template<typename MatType>
87
- void LSTMType<MatType>::SetWeights(const MatType& weights)
87
+ void LSTM<MatType>::SetWeights(const MatType& weights)
88
88
  {
89
89
  // Set the weight parameters for the inputs.
90
90
  const size_t inputWeightSize = outSize * inSize;
@@ -123,14 +123,10 @@ void LSTMType<MatType>::SetWeights(const MatType& weights)
123
123
 
124
124
  // Forward when cellState is not needed.
125
125
  template<typename MatType>
126
- void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
126
+ void LSTM<MatType>::Forward(const MatType& input, MatType& output)
127
127
  {
128
128
  // Convenience alias.
129
- const size_t batchSize = input.n_cols;
130
-
131
- // The internal quantities are stored as recurrent state; so, set aliases
132
- // correctly for this time step.
133
- SetInternalAliases(batchSize);
129
+ const size_t activeBatchSize = input.n_cols;
134
130
 
135
131
  // Compute internal state:
136
132
  //
@@ -142,25 +138,29 @@ void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
142
138
  // y_t = tanh(c_t) % o_t
143
139
 
144
140
  // Start by computing all non-recurrent portions.
145
- blockInput = blockInputWeight * input + repmat(blockInputBias, 1, batchSize);
146
- inputGate = inputGateWeight * input + repmat(inputGateBias, 1, batchSize);
147
- forgetGate = forgetGateWeight * input + repmat(forgetGateBias, 1, batchSize);
148
- outputGate = outputGateWeight * input + repmat(outputGateBias, 1, batchSize);
141
+ blockInput = blockInputWeight * input + repmat(blockInputBias, 1,
142
+ activeBatchSize);
143
+ inputGate = inputGateWeight * input + repmat(inputGateBias, 1,
144
+ activeBatchSize);
145
+ forgetGate = forgetGateWeight * input + repmat(forgetGateBias, 1,
146
+ activeBatchSize);
147
+ outputGate = outputGateWeight * input + repmat(outputGateBias, 1,
148
+ activeBatchSize);
149
149
 
150
150
  // Now add in recurrent portions, if needed.
151
151
  if (this->HasPreviousStep())
152
152
  {
153
153
  blockInput += recurrentBlockInputWeight * prevRecurrent;
154
154
  inputGate += recurrentInputGateWeight * prevRecurrent +
155
- repmat(peepholeInputGateWeight, 1, batchSize) % prevCell;
155
+ repmat(peepholeInputGateWeight, 1, activeBatchSize) % prevCell;
156
156
  forgetGate += recurrentForgetGateWeight * prevRecurrent +
157
- repmat(peepholeForgetGateWeight, 1, batchSize) % prevCell;
157
+ repmat(peepholeForgetGateWeight, 1, activeBatchSize) % prevCell;
158
158
  }
159
159
 
160
160
  // Apply nonlinearities. (TODO: fast sigmoid?)
161
161
  blockInput = tanh(blockInput);
162
- inputGate = 1.0 / (1.0 + exp(-inputGate));
163
- forgetGate = 1.0 / (1.0 + exp(-forgetGate));
162
+ inputGate = 1 / (1 + exp(-inputGate));
163
+ forgetGate = 1 / (1 + exp(-forgetGate));
164
164
 
165
165
  // Compute the cell state.
166
166
  if (this->HasPreviousStep())
@@ -172,17 +172,18 @@ void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
172
172
  if (this->HasPreviousStep())
173
173
  {
174
174
  outputGate += recurrentOutputGateWeight * prevRecurrent +
175
- repmat(peepholeOutputGateWeight, 1, batchSize) % thisCell;
175
+ repmat(peepholeOutputGateWeight, 1, activeBatchSize) % thisCell;
176
176
  }
177
177
  else
178
178
  {
179
179
  // If we don't have a previous step, we still have to consider the peephole
180
180
  // connection.
181
- outputGate += repmat(peepholeOutputGateWeight, 1, batchSize) % thisCell;
181
+ outputGate += repmat(peepholeOutputGateWeight, 1, activeBatchSize) %
182
+ thisCell;
182
183
  }
183
184
 
184
185
  // Apply nonlinearity for output gate.
185
- outputGate = 1.0 / (1.0 + exp(-outputGate));
186
+ outputGate = 1 / (1 + exp(-outputGate));
186
187
 
187
188
  // Finally, we can compute the output itself.
188
189
  output = tanh(thisCell) % outputGate;
@@ -193,7 +194,7 @@ void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
193
194
  }
194
195
 
195
196
  template<typename MatType>
196
- void LSTMType<MatType>::Backward(
197
+ void LSTM<MatType>::Backward(
197
198
  const MatType& /* input */,
198
199
  const MatType& output,
199
200
  const MatType& gy,
@@ -219,12 +220,7 @@ void LSTMType<MatType>::Backward(
219
220
  // dz_t = dc_t % i_t % (1 - z_t .^ 2)
220
221
  //
221
222
  // dx_t = W_z^T dz_t + W_i^T di_t + W_f^T df_t + W_o^T do_t
222
- //
223
- // Before we start, set all the internal aliases, which will contain this time
224
- // step's values as computed in Forward().
225
- const size_t batchSize = output.n_cols;
226
- SetInternalAliases(batchSize);
227
- SetBackwardWorkspace(batchSize);
223
+ const size_t activeBatchSize = output.n_cols;
228
224
 
229
225
  // First attempt...
230
226
  if (this->AtFinalStep())
@@ -239,35 +235,31 @@ void LSTMType<MatType>::Backward(
239
235
  recurrentOutputGateWeight.t() * nextDeltaOutputGate;
240
236
  }
241
237
 
242
- deltaOutputGate = deltaY % tanh(thisCell) % (outputGate % (1.0 - outputGate));
238
+ deltaOutputGate = deltaY % tanh(thisCell) % (outputGate % (1 - outputGate));
243
239
 
244
240
  // Only first two terms if at final step
245
241
  if (this->AtFinalStep())
246
242
  {
247
- deltaCell = deltaY % outputGate % (1.0 - square(tanh(thisCell))) +
248
- repmat(peepholeOutputGateWeight, 1, batchSize) % deltaOutputGate;
243
+ deltaCell = deltaY % outputGate % (1 - square(tanh(thisCell))) +
244
+ repmat(peepholeOutputGateWeight, 1, activeBatchSize) % deltaOutputGate;
249
245
  }
250
246
  else
251
247
  {
252
- // To update the cell state, we actually need to use the forget gate values
253
- // from the next time step.
254
- MatType nextForgetGate;
255
- MakeAlias(nextForgetGate, this->RecurrentState(this->CurrentStep() + 1),
256
- outSize, batchSize, 4 * outSize * batchSize);
257
-
258
- deltaCell = deltaY % outputGate % (1.0 - square(tanh(thisCell))) +
259
- repmat(peepholeOutputGateWeight, 1, batchSize) % deltaOutputGate +
260
- repmat(peepholeInputGateWeight, 1, batchSize) % nextDeltaInputGate +
261
- repmat(peepholeForgetGateWeight, 1, batchSize) % nextDeltaForgetGate +
248
+ deltaCell = deltaY % outputGate % (1 - square(tanh(thisCell))) +
249
+ repmat(peepholeOutputGateWeight, 1, activeBatchSize) % deltaOutputGate +
250
+ repmat(peepholeInputGateWeight, 1, activeBatchSize) %
251
+ nextDeltaInputGate +
252
+ repmat(peepholeForgetGateWeight, 1, activeBatchSize) %
253
+ nextDeltaForgetGate +
262
254
  nextDeltaCell % nextForgetGate;
263
255
  }
264
256
 
265
257
  if (this->HasPreviousStep())
266
- deltaForgetGate = deltaCell % prevCell % (forgetGate % (1.0 - forgetGate));
258
+ deltaForgetGate = deltaCell % prevCell % (forgetGate % (1 - forgetGate));
267
259
  else
268
260
  deltaForgetGate.zeros();
269
- deltaInputGate = deltaCell % blockInput % (inputGate % (1.0 - inputGate));
270
- deltaBlockInput = deltaCell % inputGate % (1.0 - square(blockInput));
261
+ deltaInputGate = deltaCell % blockInput % (inputGate % (1 - inputGate));
262
+ deltaBlockInput = deltaCell % inputGate % (1 - square(blockInput));
271
263
 
272
264
  // Finally, compute deltaX (which is what we wanted all along).
273
265
  g = blockInputWeight.t() * deltaBlockInput +
@@ -280,15 +272,11 @@ void LSTMType<MatType>::Backward(
280
272
  }
281
273
 
282
274
  template<typename MatType>
283
- void LSTMType<MatType>::Gradient(
275
+ void LSTM<MatType>::Gradient(
284
276
  const MatType& input,
285
277
  const MatType& /* error */,
286
278
  MatType& gradient)
287
279
  {
288
- // This implementation depends on Gradient() being called just after
289
- // Backward(), which is something we can safely assume. So, the workspace
290
- // aliases are already set by SetBackwardWorkspace().
291
- //
292
280
  // In this implementation we won't use aliases; we'll just address the correct
293
281
  // part of the gradient directly.
294
282
 
@@ -390,7 +378,7 @@ void LSTMType<MatType>::Gradient(
390
378
  }
391
379
 
392
380
  template<typename MatType>
393
- size_t LSTMType<MatType>::WeightSize() const
381
+ size_t LSTM<MatType>::WeightSize() const
394
382
  {
395
383
  return 4 * inSize * outSize /* input weight connections */ +
396
384
  4 * outSize /* input bias */ +
@@ -399,7 +387,7 @@ size_t LSTMType<MatType>::WeightSize() const
399
387
  }
400
388
 
401
389
  template<typename MatType>
402
- size_t LSTMType<MatType>::RecurrentSize() const
390
+ size_t LSTM<MatType>::RecurrentSize() const
403
391
  {
404
392
  // We have to account for the cell, recurrent connection, and the four
405
393
  // internal matrices: block input, input gate, forget gate, and output gate.
@@ -410,97 +398,113 @@ size_t LSTMType<MatType>::RecurrentSize() const
410
398
  }
411
399
 
412
400
  template<typename MatType>
413
- void LSTMType<MatType>::SetInternalAliases(const size_t batchSize)
401
+ void LSTM<MatType>::OnStepChanged(const size_t step,
402
+ const size_t batchSize,
403
+ const size_t activeBatchSize,
404
+ const bool backwards)
414
405
  {
415
406
  // Make all of the aliases for internal state point to the correct place.
416
- MatType& state = this->RecurrentState(this->CurrentStep());
407
+ MatType& state = this->RecurrentState(step);
417
408
 
418
409
  // First make aliases for the recurrent connections.
419
- MakeAlias(thisRecurrent, state, outSize, batchSize);
420
- MakeAlias(thisCell, state, outSize, batchSize, outSize * batchSize);
410
+ MakeAlias(thisRecurrent, state, outSize, activeBatchSize);
411
+ MakeAlias(thisCell, state, outSize, activeBatchSize, outSize * batchSize);
421
412
 
422
413
  // Now make aliases for the internal state members that we use as scratch
423
414
  // space for computation.
424
- MakeAlias(blockInput, state, outSize, batchSize, 2 * outSize * batchSize);
425
- MakeAlias(inputGate, state, outSize, batchSize, 3 * outSize * batchSize);
426
- MakeAlias(forgetGate, state, outSize, batchSize, 4 * outSize * batchSize);
427
- MakeAlias(outputGate, state, outSize, batchSize, 5 * outSize * batchSize);
415
+ MakeAlias(blockInput, state, outSize, activeBatchSize, 2 * outSize *
416
+ batchSize);
417
+ MakeAlias(inputGate, state, outSize, activeBatchSize, 3 * outSize *
418
+ batchSize);
419
+ MakeAlias(forgetGate, state, outSize, activeBatchSize, 4 * outSize *
420
+ batchSize);
421
+ MakeAlias(outputGate, state, outSize, activeBatchSize, 5 * outSize *
422
+ batchSize);
428
423
 
429
424
  // Make aliases for the previous time step, too, if we can.
430
425
  if (this->HasPreviousStep())
431
426
  {
432
427
  MatType& prevState = this->RecurrentState(this->PreviousStep());
433
428
 
434
- MakeAlias(prevRecurrent, prevState, outSize, batchSize);
435
- MakeAlias(prevCell, prevState, outSize, batchSize, outSize * batchSize);
429
+ MakeAlias(prevRecurrent, prevState, outSize, activeBatchSize);
430
+ MakeAlias(prevCell, prevState, outSize, activeBatchSize, outSize *
431
+ batchSize);
436
432
  }
437
- }
438
-
439
- template<typename MatType>
440
- void LSTMType<MatType>::SetBackwardWorkspace(const size_t batchSize)
441
- {
442
- // We need to hold enough space for two time steps.
443
- workspace.set_size(12 * outSize, batchSize);
444
433
 
445
- if (this->CurrentStep() % 2 == 0)
446
- {
447
- MakeAlias(deltaY, workspace, outSize, batchSize);
448
- MakeAlias(deltaBlockInput, workspace, outSize, batchSize,
449
- outSize * batchSize);
450
- MakeAlias(deltaInputGate, workspace, outSize, batchSize,
451
- 2 * outSize * batchSize);
452
- MakeAlias(deltaForgetGate, workspace, outSize, batchSize,
453
- 3 * outSize * batchSize);
454
- MakeAlias(deltaOutputGate, workspace, outSize, batchSize,
455
- 4 * outSize * batchSize);
456
- MakeAlias(deltaCell, workspace, outSize, batchSize,
457
- 5 * outSize * batchSize);
458
-
459
- MakeAlias(nextDeltaY, workspace, outSize, batchSize,
460
- 6 * outSize * batchSize);
461
- MakeAlias(nextDeltaBlockInput, workspace, outSize, batchSize,
462
- 7 * outSize * batchSize);
463
- MakeAlias(nextDeltaInputGate, workspace, outSize, batchSize,
464
- 8 * outSize * batchSize);
465
- MakeAlias(nextDeltaForgetGate, workspace, outSize, batchSize,
466
- 9 * outSize * batchSize);
467
- MakeAlias(nextDeltaOutputGate, workspace, outSize, batchSize,
468
- 10 * outSize * batchSize);
469
- MakeAlias(nextDeltaCell, workspace, outSize, batchSize,
470
- 11 * outSize * batchSize);
471
- }
472
- else
434
+ // Also set the workspaces for the backwards pass, if requested.
435
+ if (backwards)
473
436
  {
474
- MakeAlias(nextDeltaY, workspace, outSize, batchSize);
475
- MakeAlias(nextDeltaBlockInput, workspace, outSize, batchSize,
476
- outSize * batchSize);
477
- MakeAlias(nextDeltaInputGate, workspace, outSize, batchSize,
478
- 2 * outSize * batchSize);
479
- MakeAlias(nextDeltaForgetGate, workspace, outSize, batchSize,
480
- 3 * outSize * batchSize);
481
- MakeAlias(nextDeltaOutputGate, workspace, outSize, batchSize,
482
- 4 * outSize * batchSize);
483
- MakeAlias(nextDeltaCell, workspace, outSize, batchSize,
484
- 5 * outSize * batchSize);
485
-
486
- MakeAlias(deltaY, workspace, outSize, batchSize,
487
- 6 * outSize * batchSize);
488
- MakeAlias(deltaBlockInput, workspace, outSize, batchSize,
489
- 7 * outSize * batchSize);
490
- MakeAlias(deltaInputGate, workspace, outSize, batchSize,
491
- 8 * outSize * batchSize);
492
- MakeAlias(deltaForgetGate, workspace, outSize, batchSize,
493
- 9 * outSize * batchSize);
494
- MakeAlias(deltaOutputGate, workspace, outSize, batchSize,
495
- 10 * outSize * batchSize);
496
- MakeAlias(deltaCell, workspace, outSize, batchSize,
497
- 11 * outSize * batchSize);
437
+ // We need to hold enough space for two time steps.
438
+ workspace.set_size(12 * outSize, batchSize);
439
+
440
+ if (step % 2 == 0)
441
+ {
442
+ MakeAlias(deltaY, workspace, outSize, activeBatchSize);
443
+ MakeAlias(deltaBlockInput, workspace, outSize, activeBatchSize,
444
+ outSize * batchSize);
445
+ MakeAlias(deltaInputGate, workspace, outSize, activeBatchSize,
446
+ 2 * outSize * batchSize);
447
+ MakeAlias(deltaForgetGate, workspace, outSize, activeBatchSize,
448
+ 3 * outSize * batchSize);
449
+ MakeAlias(deltaOutputGate, workspace, outSize, activeBatchSize,
450
+ 4 * outSize * batchSize);
451
+ MakeAlias(deltaCell, workspace, outSize, activeBatchSize,
452
+ 5 * outSize * batchSize);
453
+
454
+ MakeAlias(nextDeltaY, workspace, outSize, activeBatchSize,
455
+ 6 * outSize * batchSize);
456
+ MakeAlias(nextDeltaBlockInput, workspace, outSize, activeBatchSize,
457
+ 7 * outSize * batchSize);
458
+ MakeAlias(nextDeltaInputGate, workspace, outSize, activeBatchSize,
459
+ 8 * outSize * batchSize);
460
+ MakeAlias(nextDeltaForgetGate, workspace, outSize, activeBatchSize,
461
+ 9 * outSize * batchSize);
462
+ MakeAlias(nextDeltaOutputGate, workspace, outSize, activeBatchSize,
463
+ 10 * outSize * batchSize);
464
+ MakeAlias(nextDeltaCell, workspace, outSize, activeBatchSize,
465
+ 11 * outSize * batchSize);
466
+ }
467
+ else
468
+ {
469
+ MakeAlias(nextDeltaY, workspace, outSize, activeBatchSize);
470
+ MakeAlias(nextDeltaBlockInput, workspace, outSize, activeBatchSize,
471
+ outSize * batchSize);
472
+ MakeAlias(nextDeltaInputGate, workspace, outSize, activeBatchSize,
473
+ 2 * outSize * batchSize);
474
+ MakeAlias(nextDeltaForgetGate, workspace, outSize, activeBatchSize,
475
+ 3 * outSize * batchSize);
476
+ MakeAlias(nextDeltaOutputGate, workspace, outSize, activeBatchSize,
477
+ 4 * outSize * batchSize);
478
+ MakeAlias(nextDeltaCell, workspace, outSize, activeBatchSize,
479
+ 5 * outSize * batchSize);
480
+
481
+ MakeAlias(deltaY, workspace, outSize, activeBatchSize,
482
+ 6 * outSize * batchSize);
483
+ MakeAlias(deltaBlockInput, workspace, outSize, activeBatchSize,
484
+ 7 * outSize * batchSize);
485
+ MakeAlias(deltaInputGate, workspace, outSize, activeBatchSize,
486
+ 8 * outSize * batchSize);
487
+ MakeAlias(deltaForgetGate, workspace, outSize, activeBatchSize,
488
+ 9 * outSize * batchSize);
489
+ MakeAlias(deltaOutputGate, workspace, outSize, activeBatchSize,
490
+ 10 * outSize * batchSize);
491
+ MakeAlias(deltaCell, workspace, outSize, activeBatchSize,
492
+ 11 * outSize * batchSize);
493
+ }
494
+
495
+ if (!this->AtFinalStep())
496
+ {
497
+ // To update the cell state, we actually need to use the forget gate
498
+ // values from the next time step.
499
+ MakeAlias(nextForgetGate, this->RecurrentState(this->CurrentStep() + 1),
500
+ outSize, activeBatchSize, 4 * outSize * batchSize);
501
+ }
498
502
  }
499
503
  }
500
504
 
501
505
  template<typename MatType>
502
506
  template<typename Archive>
503
- void LSTMType<MatType>::serialize(Archive& ar, const uint32_t /* version */)
507
+ void LSTM<MatType>::serialize(Archive& ar, const uint32_t /* version */)
504
508
  {
505
509
  ar(cereal::base_class<RecurrentLayer<MatType>>(this));
506
510
 
@@ -56,12 +56,15 @@ class MaxPoolingRule
56
56
  * computation.
57
57
  */
58
58
  template<typename MatType = arma::mat>
59
- class MaxPoolingType : public Layer<MatType>
59
+ class MaxPooling : public Layer<MatType>
60
60
  {
61
61
  public:
62
+ // Convenience typedefs.
63
+ using ElemType = typename MatType::elem_type;
62
64
  using CubeType = typename GetCubeType<MatType>::type;
63
- //! Create the MaxPooling object.
64
- MaxPoolingType();
65
+
66
+ // Create the MaxPooling object.
67
+ MaxPooling();
65
68
 
66
69
  /**
67
70
  * Create the MaxPooling object using the specified number of units.
@@ -73,26 +76,26 @@ class MaxPoolingType : public Layer<MatType>
73
76
  * @param floor If true, then a pooling operation that would only cover part of the
74
77
  * input will be skipped.
75
78
  */
76
- MaxPoolingType(const size_t kernelWidth,
79
+ MaxPooling(const size_t kernelWidth,
77
80
  const size_t kernelHeight,
78
81
  const size_t strideWidth = 1,
79
82
  const size_t strideHeight = 1,
80
83
  const bool floor = true);
81
84
 
82
85
  // Virtual destructor.
83
- virtual ~MaxPoolingType() { }
86
+ virtual ~MaxPooling() { }
84
87
 
85
- //! Copy the given MaxPoolingType.
86
- MaxPoolingType(const MaxPoolingType& other);
87
- //! Take ownership of the given MaxPoolingType.
88
- MaxPoolingType(MaxPoolingType&& other);
89
- //! Copy the given MaxPoolingType.
90
- MaxPoolingType& operator=(const MaxPoolingType& other);
91
- //! Take ownership of the given MaxPoolingType.
92
- MaxPoolingType& operator=(MaxPoolingType&& other);
88
+ //! Copy the given MaxPooling.
89
+ MaxPooling(const MaxPooling& other);
90
+ //! Take ownership of the given MaxPooling.
91
+ MaxPooling(MaxPooling&& other);
92
+ //! Copy the given MaxPooling.
93
+ MaxPooling& operator=(const MaxPooling& other);
94
+ //! Take ownership of the given MaxPooling.
95
+ MaxPooling& operator=(MaxPooling&& other);
93
96
 
94
- //! Clone the MaxPoolingType object. This handles polymorphism correctly.
95
- MaxPoolingType* Clone() const { return new MaxPoolingType(*this); }
97
+ //! Clone the MaxPooling object. This handles polymorphism correctly.
98
+ MaxPooling* Clone() const { return new MaxPooling(*this); }
96
99
 
97
100
  /**
98
101
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -306,10 +309,7 @@ class MaxPoolingType : public Layer<MatType>
306
309
 
307
310
  //! Locally-stored pooling indices.
308
311
  arma::Cube<size_t> poolingIndices;
309
- }; // class MaxPoolingType
310
-
311
- // Standard MaxPooling layer.
312
- using MaxPooling = MaxPoolingType<arma::mat>;
312
+ }; // class MaxPooling
313
313
 
314
314
  } // namespace mlpack
315
315
 
@@ -19,14 +19,14 @@
19
19
  namespace mlpack {
20
20
 
21
21
  template<typename MatType>
22
- MaxPoolingType<MatType>::MaxPoolingType() :
22
+ MaxPooling<MatType>::MaxPooling() :
23
23
  Layer<MatType>()
24
24
  {
25
25
  // Nothing to do here.
26
26
  }
27
27
 
28
28
  template<typename MatType>
29
- MaxPoolingType<MatType>::MaxPoolingType(
29
+ MaxPooling<MatType>::MaxPooling(
30
30
  const size_t kernelWidth,
31
31
  const size_t kernelHeight,
32
32
  const size_t strideWidth,
@@ -44,8 +44,8 @@ MaxPoolingType<MatType>::MaxPoolingType(
44
44
  }
45
45
 
46
46
  template<typename MatType>
47
- MaxPoolingType<MatType>::MaxPoolingType(
48
- const MaxPoolingType& other) :
47
+ MaxPooling<MatType>::MaxPooling(
48
+ const MaxPooling& other) :
49
49
  Layer<MatType>(other),
50
50
  kernelWidth(other.kernelWidth),
51
51
  kernelHeight(other.kernelHeight),
@@ -59,8 +59,8 @@ MaxPoolingType<MatType>::MaxPoolingType(
59
59
  }
60
60
 
61
61
  template<typename MatType>
62
- MaxPoolingType<MatType>::MaxPoolingType(
63
- MaxPoolingType&& other) :
62
+ MaxPooling<MatType>::MaxPooling(
63
+ MaxPooling&& other) :
64
64
  Layer<MatType>(std::move(other)),
65
65
  kernelWidth(std::move(other.kernelWidth)),
66
66
  kernelHeight(std::move(other.kernelHeight)),
@@ -74,8 +74,8 @@ MaxPoolingType<MatType>::MaxPoolingType(
74
74
  }
75
75
 
76
76
  template<typename MatType>
77
- MaxPoolingType<MatType>&
78
- MaxPoolingType<MatType>::operator=(const MaxPoolingType& other)
77
+ MaxPooling<MatType>&
78
+ MaxPooling<MatType>::operator=(const MaxPooling& other)
79
79
  {
80
80
  if (&other != this)
81
81
  {
@@ -93,8 +93,8 @@ MaxPoolingType<MatType>::operator=(const MaxPoolingType& other)
93
93
  }
94
94
 
95
95
  template<typename MatType>
96
- MaxPoolingType<MatType>&
97
- MaxPoolingType<MatType>::operator=(MaxPoolingType&& other)
96
+ MaxPooling<MatType>&
97
+ MaxPooling<MatType>::operator=(MaxPooling&& other)
98
98
  {
99
99
  if (&other != this)
100
100
  {
@@ -112,7 +112,7 @@ MaxPoolingType<MatType>::operator=(MaxPoolingType&& other)
112
112
  }
113
113
 
114
114
  template<typename MatType>
115
- void MaxPoolingType<MatType>::Forward(const MatType& input, MatType& output)
115
+ void MaxPooling<MatType>::Forward(const MatType& input, MatType& output)
116
116
  {
117
117
  using CubeType = typename GetCubeType<MatType>::type;
118
118
  CubeType inputTemp;
@@ -139,7 +139,7 @@ void MaxPoolingType<MatType>::Forward(const MatType& input, MatType& output)
139
139
  }
140
140
 
141
141
  template<typename MatType>
142
- void MaxPoolingType<MatType>::Backward(
142
+ void MaxPooling<MatType>::Backward(
143
143
  const MatType& input,
144
144
  const MatType& /* output */,
145
145
  const MatType& gy,
@@ -167,7 +167,7 @@ void MaxPoolingType<MatType>::Backward(
167
167
  }
168
168
 
169
169
  template<typename MatType>
170
- void MaxPoolingType<MatType>::ComputeOutputDimensions()
170
+ void MaxPooling<MatType>::ComputeOutputDimensions()
171
171
  {
172
172
  this->outputDimensions = this->inputDimensions;
173
173
 
@@ -197,7 +197,7 @@ void MaxPoolingType<MatType>::ComputeOutputDimensions()
197
197
 
198
198
  template<typename MatType>
199
199
  template<typename Archive>
200
- void MaxPoolingType<MatType>::serialize(
200
+ void MaxPooling<MatType>::serialize(
201
201
  Archive& ar,
202
202
  const uint32_t /* version */)
203
203
 
@@ -26,12 +26,15 @@ namespace mlpack {
26
26
  * computation.
27
27
  */
28
28
  template <typename MatType = arma::mat>
29
- class MeanPoolingType : public Layer<MatType>
29
+ class MeanPooling : public Layer<MatType>
30
30
  {
31
31
  public:
32
+ // Convenience typedefs.
33
+ using ElemType = typename MatType::elem_type;
32
34
  using CubeType = typename GetCubeType<MatType>::type;
33
- //! Create the MeanPoolingType object.
34
- MeanPoolingType();
35
+
36
+ // Create the MeanPooling object.
37
+ MeanPooling();
35
38
 
36
39
  /**
37
40
  * Create the MeanPooling object using the specified number of units.
@@ -43,26 +46,26 @@ class MeanPoolingType : public Layer<MatType>
43
46
  * @param floor If true, then a pooling operation that would only cover part of the
44
47
  * input will be skipped.
45
48
  */
46
- MeanPoolingType(const size_t kernelWidth,
47
- const size_t kernelHeight,
48
- const size_t strideWidth = 1,
49
- const size_t strideHeight = 1,
50
- const bool floor = true);
49
+ MeanPooling(const size_t kernelWidth,
50
+ const size_t kernelHeight,
51
+ const size_t strideWidth = 1,
52
+ const size_t strideHeight = 1,
53
+ const bool floor = true);
51
54
 
52
55
  // Virtual destructor.
53
- virtual ~MeanPoolingType() { }
56
+ virtual ~MeanPooling() { }
54
57
 
55
- //! Copy the given MeanPoolingType.
56
- MeanPoolingType(const MeanPoolingType& other);
57
- //! Take ownership of the given MeanPoolingType.
58
- MeanPoolingType(MeanPoolingType&& other);
59
- //! Copy the given MeanPoolingType.
60
- MeanPoolingType& operator=(const MeanPoolingType& other);
61
- //! Take ownership of the given MeanPoolingType.
62
- MeanPoolingType& operator=(MeanPoolingType&& other);
58
+ //! Copy the given MeanPooling.
59
+ MeanPooling(const MeanPooling& other);
60
+ //! Take ownership of the given MeanPooling.
61
+ MeanPooling(MeanPooling&& other);
62
+ //! Copy the given MeanPooling.
63
+ MeanPooling& operator=(const MeanPooling& other);
64
+ //! Take ownership of the given MeanPooling.
65
+ MeanPooling& operator=(MeanPooling&& other);
63
66
 
64
- //! Clone the MeanPoolingType object. This handles polymorphism correctly.
65
- MeanPoolingType* Clone() const { return new MeanPoolingType(*this); }
67
+ //! Clone the MeanPooling object. This handles polymorphism correctly.
68
+ MeanPooling* Clone() const { return new MeanPooling(*this); }
66
69
 
67
70
  /**
68
71
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -149,7 +152,7 @@ class MeanPoolingType : public Layer<MatType>
149
152
  */
150
153
  typename MatType::elem_type Pooling(const MatType& input)
151
154
  {
152
- return arma::mean(vectorise(input));
155
+ return mean(vectorise(input));
153
156
  }
154
157
 
155
158
  //! Locally-stored width of the pooling window.
@@ -169,10 +172,7 @@ class MeanPoolingType : public Layer<MatType>
169
172
 
170
173
  //! Locally-stored number of channels.
171
174
  size_t channels;
172
- }; // class MeanPoolingType
173
-
174
- // Standard MeanPooling layer.
175
- using MeanPooling = MeanPoolingType<arma::mat>;
175
+ }; // class MeanPooling
176
176
 
177
177
  } // namespace mlpack
178
178