mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -0,0 +1,195 @@
1
+ /**
2
+ * @file methods/ann/layer/gru.hpp
3
+ * @author Sumedh Ghaisas
4
+ * @author Zachary Ng
5
+ *
6
+ * Definition of the GRU layer.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_METHODS_ANN_LAYER_GRU_HPP
14
+ #define MLPACK_METHODS_ANN_LAYER_GRU_HPP
15
+
16
+ #include <mlpack/prereqs.hpp>
17
+
18
+ namespace mlpack {
19
+
20
+ /**
21
+ * An implementation of a gru network layer using the following algorithm.
22
+ *
23
+ * r_t = sigmoid(W_r x_t + U_r y_{t - 1})
24
+ * z_t = sigmoid(W_z x_t + U_z y_{t - 1})
25
+ * h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
26
+ * y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
27
+ *
28
+ * For more information, read the following paper:
29
+ *
30
+ * @code
31
+ * @inproceedings{chung2015gated,
32
+ * title = {Gated Feedback Recurrent Neural Networks},
33
+ * author = {Chung, Junyoung and G{\"u}l{\c{c}}ehre, Caglar and Cho,
34
+ * Kyunghyun and Bengio, Yoshua},
35
+ * booktitle = {ICML},
36
+ * pages = {2067--2075},
37
+ * year = {2015},
38
+ * url = {https://arxiv.org/abs/1502.02367}
39
+ * }
40
+ * @endcode
41
+ *
42
+ * This cell can be used in RNNs.
43
+ *
44
+ * @tparam MatType Type of the input data (arma::colvec, arma::mat,
45
+ * arma::sp_mat or arma::cube).
46
+ */
47
+ template <typename MatType = arma::mat>
48
+ class GRU : public RecurrentLayer<MatType>
49
+ {
50
+ public:
51
+ // Create the GRU object.
52
+ GRU();
53
+
54
+ /**
55
+ * Create the GRU layer object using the specified parameters.
56
+ *
57
+ * @param outSize The number of output units.
58
+ */
59
+ GRU(const size_t outSize);
60
+
61
+ // Clone the GRU object. This handles polymorphism correctly.
62
+ GRU* Clone() const { return new GRU(*this); }
63
+
64
+ // Copy the given GRU object.
65
+ GRU(const GRU& other);
66
+ // Take ownership of the given GRU object's data.
67
+ GRU(GRU&& other);
68
+ // Copy the given GRU object.
69
+ GRU& operator=(const GRU& other);
70
+ // Take ownership of the given GRU object's data.
71
+ GRU& operator=(GRU&& other);
72
+
73
+ virtual ~GRU() { }
74
+
75
+ /**
76
+ * Reset the layer parameter. The method is called to
77
+ * assign the allocated memory to the internal learnable parameters.
78
+ */
79
+ void SetWeights(const MatType& weightsIn);
80
+
81
+ /**
82
+ * Ordinary feed forward pass of a neural network, evaluating the function
83
+ * f(x) by propagating the activity forward through f.
84
+ *
85
+ * @param input Input data used for evaluating the specified function.
86
+ * @param output Resulting output activation.
87
+ */
88
+ void Forward(const MatType& input, MatType& output);
89
+
90
+ /**
91
+ * Ordinary feed backward pass of a neural network, calculating the function
92
+ * f(x) by propagating x backwards trough f. Using the results from the feed
93
+ * forward pass.
94
+ *
95
+ * @param input The input data (x) given to the forward pass.
96
+ * @param output The propagated data (f(x)) resulting from Forward()
97
+ * @param gy Propagated error from next layer.
98
+ * @param g Matrix to store propagated error in for previous layer.
99
+ */
100
+ void Backward(const MatType& /* input */,
101
+ const MatType& output,
102
+ const MatType& gy,
103
+ MatType& g);
104
+
105
+ /*
106
+ * Calculate the gradient using the output delta and the input activation.
107
+ *
108
+ * @param input Original input data provided to Forward().
109
+ * @param error Error as computed by `Backward()`.
110
+ * @param gradient Matrix to store the gradients in.
111
+ */
112
+ void Gradient(const MatType& input,
113
+ const MatType& /* error */,
114
+ MatType& gradient);
115
+
116
+ // Get the parameters.
117
+ MatType const& Parameters() const { return weights; }
118
+ // Modify the parameters.
119
+ MatType& Parameters() { return weights; }
120
+
121
+ // Get the total number of trainable parameters.
122
+ size_t WeightSize() const;
123
+
124
+ // Get the total number of recurrent state parameters.
125
+ size_t RecurrentSize() const;
126
+
127
+ // Given a properly set InputDimensions(), compute the output dimensions.
128
+ void ComputeOutputDimensions()
129
+ {
130
+ inSize = this->inputDimensions[0];
131
+ for (size_t i = 1; i < this->inputDimensions.size(); ++i)
132
+ inSize *= this->inputDimensions[i];
133
+ this->outputDimensions = std::vector<size_t>(this->inputDimensions.size(),
134
+ 1);
135
+
136
+ // The GRU layer flattens its input.
137
+ this->outputDimensions[0] = outSize;
138
+ }
139
+
140
+ // Update the internal aliases of the layer when the step changes.
141
+ void OnStepChanged(const size_t step,
142
+ const size_t batchSize,
143
+ const size_t activeBatchSize,
144
+ const bool backwards);
145
+
146
+ /**
147
+ * Serialize the layer
148
+ */
149
+ template<typename Archive>
150
+ void serialize(Archive& ar, const uint32_t /* version */);
151
+
152
+ private:
153
+ // Locally-stored number of input units.
154
+ size_t inSize;
155
+
156
+ // Locally-stored number of output units.
157
+ size_t outSize;
158
+
159
+ // Locally-stored weight object.
160
+ MatType weights;
161
+
162
+ // Weight aliases for input connections.
163
+ MatType resetGateWeight;
164
+ MatType updateGateWeight;
165
+ MatType hiddenGateWeight;
166
+
167
+ // Weight aliases for recurrent connections.
168
+ MatType recurrentResetGateWeight;
169
+ MatType recurrentUpdateGateWeight;
170
+ MatType recurrentHiddenGateWeight;
171
+
172
+ // Recurrent state aliases.
173
+ MatType resetGate;
174
+ MatType updateGate;
175
+ MatType hiddenGate;
176
+ MatType currentOutput;
177
+ MatType prevOutput;
178
+
179
+ // Backwards workspace
180
+ MatType workspace;
181
+ MatType deltaReset;
182
+ MatType deltaUpdate;
183
+ MatType deltaHidden;
184
+ // These correspond to, e.g., dy_{t + 1}.
185
+ MatType nextDeltaReset;
186
+ MatType nextDeltaUpdate;
187
+ MatType nextDeltaHidden;
188
+ }; // class GRU
189
+
190
+ } // namespace mlpack
191
+
192
+ // Include implementation.
193
+ #include "gru_impl.hpp"
194
+
195
+ #endif
@@ -0,0 +1,325 @@
1
+ /**
2
+ * @file methods/ann/layer/gru_impl.hpp
3
+ * @author Sumedh Ghaisas
4
+ * @author Zachary Ng
5
+ *
6
+ * Implementation of the GRU class, which implements a gru network
7
+ * layer.
8
+ *
9
+ * mlpack is free software; you may redistribute it and/or modify it under the
10
+ * terms of the 3-clause BSD license. You should have received a copy of the
11
+ * 3-clause BSD license along with mlpack. If not, see
12
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13
+ */
14
+ #ifndef MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP
15
+ #define MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP
16
+
17
+ // In case it hasn't yet been included.
18
+ #include "gru.hpp"
19
+
20
+ namespace mlpack {
21
+
22
+ template<typename MatType>
23
+ GRU<MatType>::GRU() :
24
+ RecurrentLayer<MatType>()
25
+ {
26
+ // Nothing to do here.
27
+ }
28
+
29
+ template<typename MatType>
30
+ GRU<MatType>::GRU(const size_t outSize) :
31
+ RecurrentLayer<MatType>(),
32
+ inSize(0),
33
+ outSize(outSize)
34
+ {
35
+ // Nothing to do here.
36
+ }
37
+
38
+ template<typename MatType>
39
+ GRU<MatType>::GRU(const GRU& other) :
40
+ RecurrentLayer<MatType>(other),
41
+ inSize(other.inSize),
42
+ outSize(other.outSize)
43
+ {
44
+ // Nothing to do here.
45
+ }
46
+
47
+ template<typename MatType>
48
+ GRU<MatType>::GRU(GRU&& other) :
49
+ RecurrentLayer<MatType>(std::move(other)),
50
+ inSize(other.inSize),
51
+ outSize(other.outSize)
52
+ {
53
+ // Nothing to do here.
54
+ }
55
+
56
+ template<typename MatType>
57
+ GRU<MatType>& GRU<MatType>::operator=(const GRU& other)
58
+ {
59
+ if (this != &other)
60
+ {
61
+ RecurrentLayer<MatType>::operator=(other);
62
+ inSize = other.inSize;
63
+ outSize = other.outSize;
64
+ }
65
+
66
+ return *this;
67
+ }
68
+
69
+ template<typename MatType>
70
+ GRU<MatType>& GRU<MatType>::operator=(GRU&& other)
71
+ {
72
+ if (this != &other)
73
+ {
74
+ RecurrentLayer<MatType>::operator=(std::move(other));
75
+ inSize = other.inSize;
76
+ outSize = other.outSize;
77
+ }
78
+
79
+ return *this;
80
+ }
81
+
82
+
83
+ template<typename MatType>
84
+ void GRU<MatType>::SetWeights(const MatType& weightsIn)
85
+ {
86
+ MakeAlias(weights, weightsIn, weightsIn.n_rows, weightsIn.n_cols);
87
+
88
+ const size_t inputWeightSize = outSize * inSize;
89
+ MakeAlias(resetGateWeight, weightsIn, outSize, inSize, 0);
90
+ MakeAlias(updateGateWeight, weightsIn, outSize, inSize, inputWeightSize);
91
+ MakeAlias(hiddenGateWeight, weightsIn, outSize, inSize, inputWeightSize * 2);
92
+
93
+ const size_t recurrentWeightOffset = inputWeightSize * 3;
94
+ const size_t recurrentWeightSize = outSize * outSize;
95
+ MakeAlias(recurrentResetGateWeight, weightsIn, outSize, outSize,
96
+ recurrentWeightOffset);
97
+ MakeAlias(recurrentUpdateGateWeight, weightsIn, outSize, outSize,
98
+ recurrentWeightOffset + recurrentWeightSize);
99
+ MakeAlias(recurrentHiddenGateWeight, weightsIn, outSize, outSize,
100
+ recurrentWeightOffset + recurrentWeightSize * 2);
101
+ }
102
+
103
+ template<typename MatType>
104
+ void GRU<MatType>::Forward(const MatType& input, MatType& output)
105
+ {
106
+ // Compute internal state using the following algorithm.
107
+ // r_t = sigmoid(W_r x_t + U_r y_{t - 1})
108
+ // z_t = sigmoid(W_z x_t + U_z y_{t - 1})
109
+ // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
110
+ // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
111
+
112
+ // Process non recurrent input.
113
+ updateGate = updateGateWeight * input;
114
+ resetGate = resetGateWeight * input;
115
+
116
+ // Add recurrent input.
117
+ if (this->HasPreviousStep())
118
+ {
119
+ resetGate += recurrentResetGateWeight * prevOutput;
120
+ updateGate += recurrentUpdateGateWeight * prevOutput;
121
+ }
122
+
123
+ // Apply sigmoid activation function.
124
+ resetGate = 1 / (1 + exp(-resetGate));
125
+ updateGate = 1 / (1 + exp(-updateGate));
126
+
127
+ // Calculate candidate activation vector.
128
+ hiddenGate = hiddenGateWeight * input;
129
+
130
+ // Add recurrent portion to activation vector.
131
+ if (this->HasPreviousStep())
132
+ {
133
+ hiddenGate += resetGate % (recurrentHiddenGateWeight * prevOutput);
134
+ }
135
+
136
+ // Apply tanh activation function.
137
+ hiddenGate = tanh(hiddenGate);
138
+
139
+ // Compute output.
140
+ output = updateGate % hiddenGate;
141
+
142
+ // Add recurrent portion to output.
143
+ if (this->HasPreviousStep())
144
+ {
145
+ output += (1 - updateGate) % prevOutput;
146
+ }
147
+
148
+ currentOutput = output;
149
+ }
150
+
151
+ template<typename MatType>
152
+ void GRU<MatType>::Backward(
153
+ const MatType& /* input */,
154
+ const MatType& /* output */,
155
+ const MatType& gy,
156
+ MatType& g)
157
+ {
158
+ // Work backwards to get error at each gate.
159
+ // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
160
+ // dh_t = dy % z_t
161
+ deltaHidden = gy % updateGate;
162
+ // The hidden gate uses a tanh activation function.
163
+ // The derivative of tanh(x) is actually 1 - tanh^2(x) but
164
+ // tanh has already been applied to hiddenGate in Forward().
165
+ deltaHidden = deltaHidden % (1 - square(hiddenGate));
166
+
167
+ // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
168
+ // dz_t = dy % h_t - dy % y_{t - 1}
169
+ deltaUpdate = gy % hiddenGate;
170
+ if (this->HasPreviousStep())
171
+ deltaUpdate -= gy % prevOutput;
172
+ // The reset and update gate use sigmoid activation.
173
+ // The derivative is sigmoid(x) * (1 - sigmoid(x)). Since sigmoid has
174
+ // already been applied to the gates, it's just `x * (1 - x)`
175
+ deltaUpdate = deltaUpdate % (updateGate % (1 - updateGate));
176
+
177
+ if (this->HasPreviousStep())
178
+ {
179
+ // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
180
+ // dr_t = dh_t % (U_h y_{t - 1})
181
+ deltaReset = deltaHidden % (recurrentHiddenGateWeight * prevOutput);
182
+ deltaReset = deltaReset % (resetGate % (1 - resetGate));
183
+ }
184
+ else
185
+ {
186
+ deltaReset.zeros(deltaHidden.n_rows, deltaHidden.n_cols);
187
+ }
188
+
189
+ // Calculate the input error.
190
+ // r_t = sigmoid(W_r x_t + U_r y_{t - 1})
191
+ // z_t = sigmoid(W_z x_t + U_z y_{t - 1})
192
+ // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
193
+ // dx_t = W_r * dr_t + W_z * dz_t + W_h * dh_t
194
+ g = resetGateWeight.t() * deltaReset +
195
+ updateGateWeight.t() * deltaUpdate +
196
+ hiddenGateWeight.t() * deltaHidden;
197
+ }
198
+
199
+ template<typename MatType>
200
+ void GRU<MatType>::Gradient(
201
+ const MatType& input,
202
+ const MatType& /* error */,
203
+ MatType& gradient)
204
+ {
205
+ size_t offset = 0;
206
+ // Non recurrent reset gate weights.
207
+ gradient.submat(offset, 0, offset + resetGateWeight.n_elem - 1, 0) =
208
+ vectorise(deltaReset * input.t());
209
+ offset += resetGateWeight.n_elem;
210
+ // Non recurrent update gate weights.
211
+ gradient.submat(offset, 0, offset + updateGateWeight.n_elem - 1, 0) =
212
+ vectorise(deltaUpdate * input.t());
213
+ offset += updateGateWeight.n_elem;
214
+ // Non recurrent hidden gate weights.
215
+ gradient.submat(offset, 0, offset + hiddenGateWeight.n_elem - 1, 0) =
216
+ vectorise(deltaHidden * input.t());
217
+ offset += hiddenGateWeight.n_elem;
218
+
219
+ // nextDelta is not set until after the first step.
220
+ if (!this->AtFinalStep())
221
+ {
222
+ // Recurrent reset gate weights.
223
+ gradient.submat(offset, 0, offset + recurrentResetGateWeight.n_elem - 1,
224
+ 0) = vectorise(nextDeltaReset * currentOutput.t());
225
+ offset += recurrentResetGateWeight.n_elem;
226
+ // Recurrent update gate weights.
227
+ gradient.submat(offset, 0, offset + recurrentUpdateGateWeight.n_elem - 1,
228
+ 0) = vectorise(nextDeltaUpdate * currentOutput.t());
229
+ offset += recurrentUpdateGateWeight.n_elem;
230
+ // Recurrent hidden gate weights.
231
+ gradient.submat(offset, 0, offset + recurrentHiddenGateWeight.n_elem - 1,
232
+ 0) = vectorise(nextDeltaHidden * currentOutput.t());
233
+ offset += recurrentHiddenGateWeight.n_elem;
234
+ }
235
+ }
236
+
237
+ template<typename MatType>
238
+ size_t GRU<MatType>::WeightSize() const
239
+ {
240
+ return outSize * inSize * 3 + /* Input weight connections */
241
+ outSize * outSize * 3; /* Recurrent weight connections */
242
+ }
243
+
244
+ template<typename MatType>
245
+ size_t GRU<MatType>::RecurrentSize() const
246
+ {
247
+ // The recurrent state has to store the output, reset gate, update gate,
248
+ // and hidden gate.
249
+ return outSize * 4;
250
+ }
251
+
252
+ template<typename MatType>
253
+ void GRU<MatType>::OnStepChanged(const size_t step,
254
+ const size_t batchSize,
255
+ const size_t activeBatchSize,
256
+ const bool backwards)
257
+ {
258
+ // Make aliases for the internal gate states from the recurrent state.
259
+ MatType& state = this->RecurrentState(step);
260
+
261
+ MakeAlias(currentOutput, state, outSize, activeBatchSize);
262
+ MakeAlias(resetGate, state, outSize, activeBatchSize, outSize * batchSize);
263
+ MakeAlias(updateGate, state, outSize, activeBatchSize, 2 * outSize *
264
+ batchSize);
265
+ MakeAlias(hiddenGate, state, outSize, activeBatchSize, 3 * outSize *
266
+ batchSize);
267
+
268
+ if (this->HasPreviousStep())
269
+ {
270
+ MatType& prevState = this->RecurrentState(this->PreviousStep());
271
+ MakeAlias(prevOutput, prevState, outSize, activeBatchSize);
272
+ }
273
+
274
+ // Also set the workspaces for the backwards pass, if requested.
275
+ if (backwards)
276
+ {
277
+ // We need to hold enough space for two time steps.
278
+ workspace.set_size(6 * outSize, batchSize);
279
+
280
+ if (step % 2 == 0)
281
+ {
282
+ MakeAlias(deltaReset, workspace, outSize, activeBatchSize);
283
+ MakeAlias(deltaUpdate, workspace, outSize, activeBatchSize,
284
+ outSize * batchSize);
285
+ MakeAlias(deltaHidden, workspace, outSize, activeBatchSize,
286
+ 2 * outSize * batchSize);
287
+
288
+ MakeAlias(nextDeltaReset, workspace, outSize, activeBatchSize,
289
+ 3 * outSize * batchSize);
290
+ MakeAlias(nextDeltaUpdate, workspace, outSize, activeBatchSize,
291
+ 4 * outSize * batchSize);
292
+ MakeAlias(nextDeltaHidden, workspace, outSize, activeBatchSize,
293
+ 5 * outSize * batchSize);
294
+ }
295
+ else
296
+ {
297
+ MakeAlias(nextDeltaReset, workspace, outSize, activeBatchSize);
298
+ MakeAlias(nextDeltaUpdate, workspace, outSize, activeBatchSize,
299
+ outSize * batchSize);
300
+ MakeAlias(nextDeltaHidden, workspace, outSize, activeBatchSize,
301
+ 2 * outSize * batchSize);
302
+
303
+ MakeAlias(deltaReset, workspace, outSize, activeBatchSize,
304
+ 3 * outSize * batchSize);
305
+ MakeAlias(deltaUpdate, workspace, outSize, activeBatchSize,
306
+ 4 * outSize * batchSize);
307
+ MakeAlias(deltaHidden, workspace, outSize, activeBatchSize,
308
+ 5 * outSize * batchSize);
309
+ }
310
+ }
311
+ }
312
+
313
+ template<typename MatType>
314
+ template<typename Archive>
315
+ void GRU<MatType>::serialize(Archive& ar, const uint32_t /* version */)
316
+ {
317
+ ar(cereal::base_class<RecurrentLayer<MatType>>(this));
318
+
319
+ ar(CEREAL_NVP(inSize));
320
+ ar(CEREAL_NVP(outSize));
321
+ }
322
+
323
+ } // namespace mlpack
324
+
325
+ #endif
@@ -46,9 +46,12 @@ namespace mlpack {
46
46
  * type to differ from the input type (Default: arma::mat).
47
47
  */
48
48
  template <typename MatType = arma::mat>
49
- class HardTanHType : public Layer<MatType>
49
+ class HardTanH : public Layer<MatType>
50
50
  {
51
51
  public:
52
+ // Convenience typedef to access the element type of the weights and data.
53
+ using ElemType = typename MatType::elem_type;
54
+
52
55
  /**
53
56
  * Create the HardTanH object using the specified parameters. The range
54
57
  * of the linear region can be adjusted by specifying the maxValue and
@@ -57,24 +60,24 @@ class HardTanHType : public Layer<MatType>
57
60
  * @param maxValue Range of the linear region maximum value.
58
61
  * @param minValue Range of the linear region minimum value.
59
62
  */
60
- HardTanHType(const double maxValue = 1, const double minValue = -1);
63
+ HardTanH(const double maxValue = 1, const double minValue = -1);
61
64
 
62
- virtual ~HardTanHType() { }
65
+ virtual ~HardTanH() { }
63
66
 
64
67
  //! Copy the other HardTanH layer
65
- HardTanHType(const HardTanHType& layer);
68
+ HardTanH(const HardTanH& layer);
66
69
 
67
70
  //! Take ownership of the members of the other HardTanH Layer
68
- HardTanHType(HardTanHType&& layer);
71
+ HardTanH(HardTanH&& layer);
69
72
 
70
73
  //! Copy the other HardTanH layer
71
- HardTanHType& operator=(const HardTanHType& layer);
74
+ HardTanH& operator=(const HardTanH& layer);
72
75
 
73
76
  //! Take ownership of the members of the other HardTanH Layer
74
- HardTanHType& operator=(HardTanHType&& layer);
77
+ HardTanH& operator=(HardTanH&& layer);
75
78
 
76
- //! Clone the HardTanHType object. This handles polymorphism correctly.
77
- HardTanHType* Clone() const { return new HardTanHType(*this); }
79
+ //! Clone the HardTanH object. This handles polymorphism correctly.
80
+ HardTanH* Clone() const { return new HardTanH(*this); }
78
81
 
79
82
  /**
80
83
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -122,12 +125,7 @@ class HardTanHType : public Layer<MatType>
122
125
 
123
126
  //! Minimum value for the HardTanH function.
124
127
  double minValue;
125
- }; // class HardTanHType
126
-
127
- // Convenience typedefs.
128
-
129
- // Standard HardTanH layer.
130
- using HardTanH = HardTanHType<arma::mat>;
128
+ }; // class HardTanH
131
129
 
132
130
  } // namespace mlpack
133
131
 
@@ -19,7 +19,7 @@
19
19
  namespace mlpack {
20
20
 
21
21
  template<typename MatType>
22
- HardTanHType<MatType>::HardTanHType(
22
+ HardTanH<MatType>::HardTanH(
23
23
  const double maxValue,
24
24
  const double minValue) :
25
25
  Layer<MatType>(),
@@ -30,7 +30,7 @@ HardTanHType<MatType>::HardTanHType(
30
30
  }
31
31
 
32
32
  template<typename MatType>
33
- HardTanHType<MatType>::HardTanHType(const HardTanHType& layer) :
33
+ HardTanH<MatType>::HardTanH(const HardTanH& layer) :
34
34
  Layer<MatType>(layer),
35
35
  maxValue(layer.maxValue),
36
36
  minValue(layer.minValue)
@@ -39,7 +39,7 @@ HardTanHType<MatType>::HardTanHType(const HardTanHType& layer) :
39
39
  }
40
40
 
41
41
  template<typename MatType>
42
- HardTanHType<MatType>::HardTanHType(HardTanHType&& layer) :
42
+ HardTanH<MatType>::HardTanH(HardTanH&& layer) :
43
43
  Layer<MatType>(std::move(layer)),
44
44
  maxValue(std::move(layer.maxValue)),
45
45
  minValue(std::move(layer.minValue))
@@ -48,8 +48,8 @@ HardTanHType<MatType>::HardTanHType(HardTanHType&& layer) :
48
48
  }
49
49
 
50
50
  template<typename MatType>
51
- HardTanHType<MatType>& HardTanHType<MatType>::operator=(
52
- const HardTanHType& layer)
51
+ HardTanH<MatType>& HardTanH<MatType>::operator=(
52
+ const HardTanH& layer)
53
53
  {
54
54
  if (&layer != this)
55
55
  {
@@ -62,7 +62,7 @@ HardTanHType<MatType>& HardTanHType<MatType>::operator=(
62
62
  }
63
63
 
64
64
  template<typename MatType>
65
- HardTanHType<MatType>& HardTanHType<MatType>::operator=(HardTanHType&& layer)
65
+ HardTanH<MatType>& HardTanH<MatType>::operator=(HardTanH&& layer)
66
66
  {
67
67
  if (&layer != this)
68
68
  {
@@ -74,19 +74,19 @@ HardTanHType<MatType>& HardTanHType<MatType>::operator=(HardTanHType&& layer)
74
74
  return *this;
75
75
  }
76
76
  template<typename MatType>
77
- void HardTanHType<MatType>::Forward(
77
+ void HardTanH<MatType>::Forward(
78
78
  const MatType& input, MatType& output)
79
79
  {
80
80
  #pragma omp parallel for
81
81
  for (size_t i = 0; i < input.n_elem; ++i)
82
82
  {
83
- output(i) = (input(i) > maxValue ? maxValue :
84
- (input(i) < minValue ? minValue : input(i)));
83
+ output(i) = (input(i) > ElemType(maxValue) ? ElemType(maxValue) :
84
+ (input(i) < ElemType(minValue) ? ElemType(minValue) : input(i)));
85
85
  }
86
86
  }
87
87
 
88
88
  template<typename MatType>
89
- void HardTanHType<MatType>::Backward(
89
+ void HardTanH<MatType>::Backward(
90
90
  const MatType& input,
91
91
  const MatType& /* output */,
92
92
  const MatType& gy,
@@ -99,7 +99,7 @@ void HardTanHType<MatType>::Backward(
99
99
  {
100
100
  // input should not have any values greater than maxValue
101
101
  // and lesser than minValue
102
- if (input(i) <= minValue || input(i) >= maxValue)
102
+ if (input(i) <= ElemType(minValue) || input(i) >= ElemType(maxValue))
103
103
  {
104
104
  g(i) = 0;
105
105
  }
@@ -108,7 +108,7 @@ void HardTanHType<MatType>::Backward(
108
108
 
109
109
  template<typename MatType>
110
110
  template<typename Archive>
111
- void HardTanHType<MatType>::serialize(
111
+ void HardTanH<MatType>::serialize(
112
112
  Archive& ar,
113
113
  const uint32_t /* version */)
114
114
  {