mlpack-4.6.2-cp38-cp38-win_amd64.whl → mlpack-4.7.0-cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (414)
  1. mlpack/__init__.py +3 -3
  2. mlpack/adaboost_classify.cp38-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp38-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp38-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp38-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp38-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp38-win_amd64.pyd +0 -0
  8. mlpack/cf.cp38-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp38-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp38-win_amd64.pyd +0 -0
  11. mlpack/det.cp38-win_amd64.pyd +0 -0
  12. mlpack/emst.cp38-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp38-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp38-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp38-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp38-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp38-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp38-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp38-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp38-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp38-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp38-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp38-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp38-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp38-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp38-win_amd64.pyd +0 -0
  365. mlpack/knn.cp38-win_amd64.pyd +0 -0
  366. mlpack/krann.cp38-win_amd64.pyd +0 -0
  367. mlpack/lars.cp38-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp38-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp38-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp38-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp38-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp38-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp38-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp38-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp38-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp38-win_amd64.pyd +0 -0
  377. mlpack/nca.cp38-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp38-win_amd64.pyd +0 -0
  379. mlpack/pca.cp38-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp38-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp38-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp38-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp38-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp38-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp38-win_amd64.pyd +0 -0
  386. mlpack/radical.cp38-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp38-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp38-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp38-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +5 -5
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +395 -376
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{.load-order-mlpack-4.6.2 → .load-order-mlpack-4.7.0} +0 -0
mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp
@@ -22,8 +22,8 @@
  namespace mlpack {
 
  template <typename MatType, typename RegularizerType>
- MultiheadAttentionType<MatType, RegularizerType>::
- MultiheadAttentionType() :
+ MultiheadAttention<MatType, RegularizerType>::
+ MultiheadAttention() :
  tgtSeqLen(0),
  srcSeqLen(0),
  embedDim(0),
@@ -35,11 +35,11 @@ MultiheadAttentionType() :
  }
 
  template <typename MatType, typename RegularizerType>
- MultiheadAttentionType<MatType, RegularizerType>::
- MultiheadAttentionType(
+ MultiheadAttention<MatType, RegularizerType>::
+ MultiheadAttention(
  const size_t tgtSeqLen,
  const size_t numHeads,
- const MatType& attnmask,
+ const CubeType& attnmask,
  const MatType& keypaddingmask,
  const bool selfAttention) :
  tgtSeqLen(tgtSeqLen),
@@ -53,7 +53,7 @@ MultiheadAttentionType(
  }
 
  template <typename MatType, typename RegularizerType>
- void MultiheadAttentionType<MatType, RegularizerType>::SetWeights(
+ void MultiheadAttention<MatType, RegularizerType>::SetWeights(
  const MatType& weightsIn)
  {
  MakeAlias(weights, weightsIn, (4 * embedDim + 4) * embedDim, 1);
@@ -70,7 +70,7 @@ void MultiheadAttentionType<MatType, RegularizerType>::SetWeights(
  }
 
  template <typename MatType, typename RegularizerType>
- void MultiheadAttentionType<MatType, RegularizerType>::
+ void MultiheadAttention<MatType, RegularizerType>::
  Forward(const MatType& input, MatType& output)
  {
  if (input.n_rows != embedDim *
@@ -122,7 +122,7 @@ Forward(const MatType& input, MatType& output)
 
  // The scaling factor sqrt(headDim) is used to prevent exploding values
  // after dot product i.e. when qProj is multiplied with kProj.
- qProj /= std::sqrt(headDim);
+ qProj /= ElemType(std::sqrt(headDim));
 
  // Split the qProj, kProj and vProj into n heads. That's what Multihead
  // Attention is.
@@ -131,40 +131,16 @@ Forward(const MatType& input, MatType& output)
  vProj.reshape(srcSeqLen, headDim, numHeads * batchSize);
 
  // Calculate the scores i.e. perform the matrix multiplication operation
- // on qProj and kProj. Here score = qProj . kProj'
- scores = MultiplyCube2Cube(qProj, kProj, false, true);
-
- // Apply the attention mask if provided. The attention mask is used to black-
- // out future sequences and generally used in Encoder-Decoder attention.
- // The attention mask has elements -inf or 0.
- // The shape of the attention mask : (tgtSeqLen, srcSeqLen).
- if (!attnMask.is_empty())
- {
- if (attnMask.n_rows != tgtSeqLen || attnMask.n_cols != srcSeqLen)
- Log::Fatal << "The size of the 'attn_mask' is not correct.\n";
- scores.each_slice() += attnMask;
- }
-
- // Apply the key padding mask when provided. It blacks-out any particular
- // word in the sequence.
- // The key padding mask has elements -inf or 0
- // The shape of keyPaddingMask : (1, srcSeqLen).
- if (!keyPaddingMask.is_empty())
- {
- if (keyPaddingMask.n_rows != 1 || keyPaddingMask.n_cols != srcSeqLen)
- Log::Fatal << "The size of the 'keyPaddingMask' is not correct.\n";
- scores.each_slice() += repmat(keyPaddingMask, tgtSeqLen, 1);
- }
+ // on qProj and kProj. Here score = kProj . qProj'
+ scores = MultiplyCube2Cube(kProj, qProj, false, true);
 
- for (size_t i = 0; i < numHeads * batchSize; ++i)
- {
- softmax.Forward(scores.slice(i), scores.slice(i));
- }
+ // Apply softmax to non-masked elements.
+ MaskedForwardSoftmax(scores, numHeads, batchSize, attnMask, keyPaddingMask);
 
  // Calculate the attention output i.e. matrix multiplication of softmax
  // output and vProj.
  // The shape of attnOutput : (tgtSeqLen, headDim, numHeads * batchSize).
- attnOut = MultiplyCube2Cube(scores, vProj, false, false);
+ attnOut = MultiplyCube2Cube(scores, vProj, true, false);
 
  // Now we will concatenate output of all the heads i.e. we will reshape
  // attnOut to (tgtSeqLen, embedDim, batchSize).
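Note: as the hunk above shows, 4.7.0 stores the score cube transposed relative to 4.6.2. Each slice is now kProj . qProj' with shape (srcSeqLen, tgtSeqLen), so the attention output is formed from the transposed scores. A minimal single-head Armadillo sketch of the new orientation follows; it is a standalone shape check, illustrative only, and not part of this diff.

// Standalone shape check (illustrative, not taken from the wheel): in the
// 4.7.0 layout the scores are (srcSeqLen x tgtSeqLen) and are transposed
// when forming the attention output.
#include <armadillo>

int main()
{
  const arma::uword tgtSeqLen = 3, srcSeqLen = 5, headDim = 4;

  arma::mat qProj(tgtSeqLen, headDim, arma::fill::randn);
  arma::mat kProj(srcSeqLen, headDim, arma::fill::randn);
  arma::mat vProj(srcSeqLen, headDim, arma::fill::randn);

  // score = kProj . qProj'  -> (srcSeqLen x tgtSeqLen).
  arma::mat scores = kProj * qProj.t();

  // attnOut = scores' . vProj -> (tgtSeqLen x headDim), which is what
  // MultiplyCube2Cube(scores, vProj, true, false) computes per slice.
  arma::mat attnOut = scores.t() * vProj;

  attnOut.print("attnOut (tgtSeqLen x headDim):");
  return 0;
}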
@@ -173,13 +149,13 @@ Forward(const MatType& input, MatType& output)
  // The final output is the linear projection of attention output.
  for (size_t i = 0; i < batchSize; ++i)
  {
- output.col(i) = vectorise(trans(attnOut.slice(i) * outWt
+ output.col(i) = vectorise(trans(attnOut.slice(i) * outWt.t()
  + repmat(outBias, tgtSeqLen, 1)));
  }
  }
 
  template <typename MatType, typename RegularizerType>
- void MultiheadAttentionType<MatType, RegularizerType>::
+ void MultiheadAttention<MatType, RegularizerType>::
  Backward(const MatType& /* input */,
  const MatType& /* output */,
  const MatType& gy,
@@ -207,7 +183,7 @@ Backward(const MatType& /* input */,
  // The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
  // The shape of outWt : (embedDim, embedDim).
  // The shape of the result : (tgtSeqLen, embedDim, batchSize).
- gyTemp = MultiplyCube2Mat(gyTemp, outWt, true, true);
+ gyTemp = MultiplyCube2Mat(gyTemp, outWt, true, false);
 
  // Now since the shape of gyTemp is (tgtSeqLen, embedDim, batchSize). We will
  // split it into n heads.
@@ -216,9 +192,9 @@ Backward(const MatType& /* input */,
 
  // Obtain backpropagted error of value.
  // Shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
- // Shape of scores : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // Shape of scores : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The shape of tmp : (srcSeqLen, headDim, numHeads * batchSize).
- CubeType tmp = MultiplyCube2Cube(scores, gyTemp, true, false);
+ CubeType tmp = MultiplyCube2Cube(scores, gyTemp, false, false);
 
  // Concatenate results of all the attention heads.
  tmp.reshape(srcSeqLen, embedDim, batchSize);
@@ -239,8 +215,8 @@ Backward(const MatType& /* input */,
 
  // The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  // The shape of vProj : (srcSeqLen, headDim, numHeads * batchSize).
- // So the new shape of gyTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
- gyTemp = MultiplyCube2Cube(gyTemp, vProj, false, true);
+ // So the new shape of gyTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
+ gyTemp = MultiplyCube2Cube(vProj, gyTemp, false, true);
 
  for (size_t i = 0; i < numHeads * batchSize; ++i)
  {
@@ -251,9 +227,9 @@ Backward(const MatType& /* input */,
 
  // Obtain backpropagated error of key.
  // The shape of qProj : (tgtSeqLen, headDim, numHeads * batchSize).
- // The shape of gyTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // The shape of gyTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The new shape of tmp : (srcSeqLen, headDim, numHeads * batchSize).
- tmp = MultiplyCube2Cube(gyTemp, qProj, true, false);
+ tmp = MultiplyCube2Cube(gyTemp, qProj, false, false);
 
  // Concatenate results of all the attention heads.
  tmp.reshape(srcSeqLen, embedDim, batchSize);
@@ -276,9 +252,10 @@ Backward(const MatType& /* input */,
 
  // Obtain backpropagated error of the query.
  // The shape of kProj : (srcSeqLen, headDim, numHeads * batchSize).
- // The shape of gyTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // The shape of gyTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The new shape of tmp : (tgtSeqLen, headDim, numHeads * batchSize).
- tmp = MultiplyCube2Cube(gyTemp, kProj) / std::sqrt(headDim);
+ tmp = MultiplyCube2Cube(gyTemp, kProj, true, false) /
+ ElemType(std::sqrt(headDim));
 
  // Concatenate results of all the attention heads.
  tmp.reshape(tgtSeqLen, embedDim, batchSize);
@@ -300,7 +277,7 @@ Backward(const MatType& /* input */,
  }
 
  template <typename MatType, typename RegularizerType>
- void MultiheadAttentionType<MatType, RegularizerType>::
+ void MultiheadAttention<MatType, RegularizerType>::
  Gradient(const MatType& input,
  const MatType& error,
  MatType& gradient)
@@ -327,7 +304,7 @@ Gradient(const MatType& input,
  const size_t wtSize = embedDim * embedDim;
 
  // The shape of gradient : (4 * embedDim * embedDim + 4 * embedDim, 1).
- gradient.set_size(arma::size(weights));
+ gradient.set_size(size(weights));
 
  const CubeType q, k, v;
  MakeAlias(const_cast<CubeType&>(q), input, embedDim, tgtSeqLen, batchSize,
@@ -356,22 +333,23 @@ Gradient(const MatType& input,
 
  // Gradient wrt. outWt, i.e. dL/d(outWt). We will take sum of gyTemp along
  // the slices and vectorise the output.
- gradient.rows(3 * wtSize, 4 * wtSize - 1) = vectorise(sum(gyTemp, 2));
+ CubeType tmpCube = sum(gyTemp, 2);
+ gradient.rows(3 * wtSize, 4 * wtSize - 1) = vectorise(tmpCube.slice(0).t());
 
  // Partial derivative wrt. attnOut.
  // The shape of outWt : (embedDim, embedDim).
  // The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
  // The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
- gyTemp = MultiplyCube2Mat(errorTemp, outWt, true, true);
+ gyTemp = MultiplyCube2Mat(errorTemp, outWt, true, false);
 
  // Now we will split it into n heads i.e. reshape it into a cube of shape
  // (tgtSeqLen, headDim, numHeads * batchSize).
  gyTemp.reshape(tgtSeqLen, headDim, numHeads * batchSize);
 
  // Shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
- // Shape of scores : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // Shape of scores : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The new shape of errorTemp : (srcSeqLen, headDim, numHeads * batchSize).
- errorTemp = MultiplyCube2Cube(scores, gyTemp, true, false);
+ errorTemp = MultiplyCube2Cube(scores, gyTemp, false, false);
 
  // Now we will concatenate the propagated errors from all heads i.e. we
  // will reshape errorTemp to (srcSeqLen, embedDim, batchSize).
@@ -393,22 +371,23 @@ Gradient(const MatType& input,
 
  // Now, the shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  // The shape of vProj : (srcSeqLen, headDim, numHeads * batchSize).
- // The new shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
- errorTemp = MultiplyCube2Cube(gyTemp, vProj, false, true);
+ // The new shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
+ errorTemp = MultiplyCube2Cube(vProj, gyTemp, false, true);
 
  for (size_t i = 0; i < numHeads * batchSize; ++i)
  {
- // The shape of scores : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
- // The shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // The shape of scores : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
+ // The shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The new shape of errorTemp remain same.
  softmax.Backward({} /* unused */, scores.slice(i), errorTemp.slice(i),
  errorTemp.slice(i));
  }
 
+
  // The shape of qProj : (tgtSeqLen, headDim, numHeads * batchSize).
- // The shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // The shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The shape of gyTemp : (srcSeqLen, headDim, numHeads * batchSize).
- gyTemp = MultiplyCube2Cube(errorTemp, qProj, true, false);
+ gyTemp = MultiplyCube2Cube(errorTemp, qProj, false, false);
 
  // We will now conctenate the propagated errors from all heads.
  // The new shape of gyTemp : (srcSeqLen, embedDim, batchSize).
@@ -429,13 +408,13 @@ Gradient(const MatType& input,
  gradient.rows(wtSize, 2 * wtSize - 1) = vectorise(sum(gyTemp, 2));
 
  // The shape of kProj : (srcSeqLen, headDim, numHeads * batchSize).
- // The shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
+ // The shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
  // The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
- gyTemp = MultiplyCube2Cube(errorTemp, kProj, false, false);
+ gyTemp = MultiplyCube2Cube(errorTemp, kProj, true, false);
 
  // Now, we will concatenate propagated error of all heads.
  gyTemp.reshape(tgtSeqLen, embedDim, batchSize);
- gyTemp /= std::sqrt(headDim);
+ gyTemp /= ElemType(std::sqrt(headDim));
 
  // Gradient wrt. qBias, i.e. dL/d(qBias). We will take summation over all the
  // batches of gyTemp and over all the sequences.
@@ -457,7 +436,7 @@ Gradient(const MatType& input,
 
  template <typename MatType, typename RegularizerType>
  template <typename Archive>
- void MultiheadAttentionType<MatType, RegularizerType>::
+ void MultiheadAttention<MatType, RegularizerType>::
  serialize(Archive& ar, const uint32_t /* version */)
  {
  ar(cereal::base_class<Layer<MatType>>(this));
@@ -492,6 +471,124 @@ serialize(Archive& ar, const uint32_t /* version */)
  }
  }
 
+ template<typename MatType, typename RegularizerType>
+ void MultiheadAttention<MatType, RegularizerType>::MaskedForwardSoftmax(
+ CubeType& scores,
+ const size_t numHeads,
+ const size_t batchSize,
+ const CubeType& attnMask,
+ const MatType& keyPaddingMask)
+ {
+ if (attnMask.empty() && keyPaddingMask.empty())
+ {
+ // No masking required: we can use the simple implementation.
+ for (size_t i = 0; i < scores.n_slices; ++i)
+ {
+ scores.slice(i) = exp(scores.slice(i).each_row() -
+ max(scores.slice(i), 0));
+ scores.slice(i).each_row() /= sum(scores.slice(i), 0);
+ }
+ }
+ else if (attnMask.empty() && !keyPaddingMask.empty())
+ {
+ // There is one key padding mask column for each element in the batch.
+ for (size_t i = 0; i < batchSize; ++i)
+ {
+ for (size_t h = 0; h < numHeads; ++h)
+ {
+ const size_t s = i * numHeads + h;
+
+ for (size_t c = 0; c < scores.n_cols; ++c)
+ {
+ ElemType maxVal = std::numeric_limits<ElemType>::lowest();
+ for (size_t r = 0; r < scores.n_rows; ++r)
+ if (keyPaddingMask(r, i) >= ElemType(0) && scores(r, c, s) > maxVal)
+ maxVal = scores(r, c, s);
+
+ for (size_t r = 0; r < scores.n_rows; ++r)
+ {
+ if (keyPaddingMask(r, i) < ElemType(0))
+ scores(r, c, s) = ElemType(0);
+ else
+ scores(r, c, s) = std::exp(scores(r, c, s) - maxVal);
+ }
+
+ if (maxVal != std::numeric_limits<ElemType>::lowest())
+ scores.slice(s).col(c) /= sum(scores.slice(s).col(c));
+ }
+ }
+ }
+ }
+ else if (!attnMask.empty() && keyPaddingMask.empty())
+ {
+ // There is one attention mask for each element in the batch.
+ for (size_t i = 0; i < batchSize; ++i)
+ {
+ for (size_t h = 0; h < numHeads; ++h)
+ {
+ const size_t s = i * numHeads + h;
+
+ for (size_t c = 0; c < scores.n_cols; ++c)
+ {
+ ElemType maxVal = std::numeric_limits<ElemType>::lowest();
+ for (size_t r = 0; r < scores.n_rows; ++r)
+ if (attnMask(r, c, i) >= ElemType(0) && scores(r, c, s) > maxVal)
+ maxVal = scores(r, c, s);
+
+ for (size_t r = 0; r < scores.n_rows; ++r)
+ {
+ if (attnMask(r, c, i) < ElemType(0))
+ scores(r, c, s) = ElemType(0);
+ else
+ scores(r, c, s) = std::exp(scores(r, c, s) - maxVal);
+ }
+
+ if (maxVal != std::numeric_limits<ElemType>::lowest())
+ scores.slice(s).col(c) /= sum(scores.slice(s).col(c));
+ }
+ }
+ }
+ }
+ else // !attnMask.empty() && !keyPaddingMask.empty()
+ {
+ // There is one key padding mask column for each element in the batch, and
+ // one attention mask for each element in the batch.
+ for (size_t i = 0; i < batchSize; ++i)
+ {
+ for (size_t h = 0; h < numHeads; ++h)
+ {
+ const size_t s = i * numHeads + h;
+
+ for (size_t c = 0; c < scores.n_cols; ++c)
+ {
+ ElemType maxVal = std::numeric_limits<ElemType>::lowest();
+ for (size_t r = 0; r < scores.n_rows; ++r)
+ {
+ if (attnMask(r, c, i) >= ElemType(0) &&
+ keyPaddingMask(r, i) >= ElemType(0) &&
+ scores(r, c, s) > maxVal)
+ {
+ maxVal = scores(r, c, s);
+ }
+ }
+
+ for (size_t r = 0; r < scores.n_rows; ++r)
+ {
+ if (attnMask(r, c, i) < ElemType(0) ||
+ keyPaddingMask(r, i) < ElemType(0))
+ scores(r, c, s) = ElemType(0);
+ else
+ scores(r, c, s) = std::exp(scores(r, c, s) - maxVal);
+ }
+
+ if (maxVal != std::numeric_limits<ElemType>::lowest())
+ scores.slice(s).col(c) /= sum(scores.slice(s).col(c));
+ }
+ }
+ }
+ }
+ }
+
  } // namespace mlpack
 
  #endif
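Note: the new MaskedForwardSoftmax() helper above replaces the old additive-mask-then-softmax path; entries whose attention or key-padding mask value is negative are skipped in the max/exp/normalize steps and end up with zero weight. Below is a minimal standalone sketch of that column-wise masked softmax in plain Armadillo; the matrix shapes and the sign convention of mask are illustrative assumptions, not taken from this diff.

// Illustrative only (not part of the wheel): numerically stable column-wise
// softmax that ignores entries whose mask value is negative, mirroring the
// behaviour of the new MaskedForwardSoftmax() helper.
#include <armadillo>
#include <cmath>
#include <limits>

arma::mat MaskedColSoftmax(arma::mat scores, const arma::mat& mask)
{
  for (arma::uword c = 0; c < scores.n_cols; ++c)
  {
    // Maximum over unmasked entries, for numerical stability.
    double maxVal = std::numeric_limits<double>::lowest();
    for (arma::uword r = 0; r < scores.n_rows; ++r)
      if (mask(r, c) >= 0.0 && scores(r, c) > maxVal)
        maxVal = scores(r, c);

    // Exponentiate unmasked entries; masked entries get zero weight.
    for (arma::uword r = 0; r < scores.n_rows; ++r)
      scores(r, c) = (mask(r, c) < 0.0) ? 0.0
                                        : std::exp(scores(r, c) - maxVal);

    // Normalize only if at least one entry in the column was unmasked.
    if (maxVal != std::numeric_limits<double>::lowest())
      scores.col(c) /= arma::accu(scores.col(c));
  }
  return scores;
}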
mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp
@@ -1,4 +1,3 @@
- //
  /**
  * @filer methods/ann/layer/nearest_interpolation.hpp
  * @author Andrew Furey
@@ -29,14 +28,18 @@ namespace mlpack {
  * arma::sp_mat or arma::cube).
  */
  template<typename MatType = arma::mat>
- class NearestInterpolationType : public Layer<MatType>
+ class NearestInterpolation : public Layer<MatType>
  {
  public:
+ // Convenience typedefs.
+ using ElemType = typename MatType::elem_type;
  using CubeType = typename GetCubeType<MatType>::type;
- //! Create the NearestInterpolation object.
- NearestInterpolationType();
 
- /**Create NearestInterpolation Object with the same scaleFactor along
+ // Create the NearestInterpolation object.
+ NearestInterpolation();
+
+ /**
+ * Create NearestInterpolation Object with the same scaleFactor along
  * each dimension.
  * NOTE: scaleFactors must be a two element vector, the first element
  * for scaling the first dimension and the second element for scaling
@@ -44,25 +47,25 @@ class NearestInterpolationType : public Layer<MatType>
  *
  * If the input dimensions are n x m x ..., then the output dimensions
  * will be (n x scaleFactors[0]) x (m x scaleFactors[1]) x ...
- * 
+ *
  * @param scaleFactor Scale factors to scale each dimension by.
  */
- NearestInterpolationType(const std::vector<double> scaleFactors);
+ NearestInterpolation(const std::vector<double> scaleFactors);
 
- NearestInterpolationType* Clone() const {
- return new NearestInterpolationType(*this);
+ NearestInterpolation* Clone() const {
+ return new NearestInterpolation(*this);
  }
 
- virtual ~NearestInterpolationType() { }
+ virtual ~NearestInterpolation() { }
 
- //! Copy the given NearestInterpolationType layer.
- NearestInterpolationType(const NearestInterpolationType& other);
- //! Take ownership of the given NearestInterpolationType layer.
- NearestInterpolationType(NearestInterpolationType&& other);
- //! Copy the given NearestInterpolationType layer.
- NearestInterpolationType& operator=(const NearestInterpolationType& other);
- //! Take ownership of the given NearestInterpolationType layer.
- NearestInterpolationType& operator=(NearestInterpolationType&& other);
+ //! Copy the given NearestInterpolation layer.
+ NearestInterpolation(const NearestInterpolation& other);
+ //! Take ownership of the given NearestInterpolation layer.
+ NearestInterpolation(NearestInterpolation&& other);
+ //! Copy the given NearestInterpolation layer.
+ NearestInterpolation& operator=(const NearestInterpolation& other);
+ //! Take ownership of the given NearestInterpolation layer.
+ NearestInterpolation& operator=(NearestInterpolation&& other);
 
  /**
  * Forward pass through the layer. The layer interpolates
@@ -81,12 +84,14 @@ class NearestInterpolationType : public Layer<MatType>
  * the input size.
  *
  * @param * (input) The input matrix.
- * @param gradient The computed backward gradient.
- * @param output The resulting down-sampled output.
+ * @param * (output) The output matrix.
+ * @param gy The computed backward gradient.
+ * @param g The resulting down-sampled output.
  */
- void Backward(const MatType& /*input*/,
- const MatType& gradient,
- MatType& output);
+ void Backward(const MatType& /* input */,
+ const MatType& /* output */,
+ const MatType& gy,
+ MatType& g);
 
  //! Compute the output dimensions of the layer, based on the internal values
  //! of `InputDimensions()`.
@@ -103,8 +108,6 @@ class NearestInterpolationType : public Layer<MatType>
  std::vector<double> scaleFactors;
  }; // class NearestInterpolation
 
- using NearestInterpolation = NearestInterpolationType<arma::mat>;
-
  } // namespace mlpack
 
  // Include implementation.
mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp
@@ -19,16 +19,16 @@
  namespace mlpack {
 
  template<typename MatType>
- NearestInterpolationType<MatType>::NearestInterpolationType():
- Layer<MatType>()
+ NearestInterpolation<MatType>::NearestInterpolation():
+ Layer<MatType>()
  {
  // Nothing to do here.
  }
 
  template<typename MatType>
- NearestInterpolationType<MatType>::
- NearestInterpolationType(const std::vector<double> scaleFactors) :
- Layer<MatType>()
+ NearestInterpolation<MatType>::
+ NearestInterpolation(const std::vector<double> scaleFactors) :
+ Layer<MatType>()
  {
  if (scaleFactors.size() != 2) {
  throw std::runtime_error("Scale factors must have 2 dimensions");
@@ -37,27 +37,27 @@ NearestInterpolationType(const std::vector<double> scaleFactors) :
  }
 
  template<typename MatType>
- NearestInterpolationType<MatType>::
- NearestInterpolationType(const NearestInterpolationType& other) :
- Layer<MatType>(),
- scaleFactors(other.scaleFactors)
+ NearestInterpolation<MatType>::
+ NearestInterpolation(const NearestInterpolation& other) :
+ Layer<MatType>(),
+ scaleFactors(other.scaleFactors)
  {
  // Nothing to do here.
  }
 
  template<typename MatType>
- NearestInterpolationType<MatType>::
- NearestInterpolationType(NearestInterpolationType&& other) :
- Layer<MatType>(std::move(other)),
- scaleFactors(std::move(other.scaleFactors))
+ NearestInterpolation<MatType>::
+ NearestInterpolation(NearestInterpolation&& other) :
+ Layer<MatType>(std::move(other)),
+ scaleFactors(std::move(other.scaleFactors))
  {
  // Nothing to do here.
  }
 
  template<typename MatType>
- NearestInterpolationType<MatType>&
- NearestInterpolationType<MatType>::
- operator=(const NearestInterpolationType& other)
+ NearestInterpolation<MatType>&
+ NearestInterpolation<MatType>::
+ operator=(const NearestInterpolation& other)
  {
  if (&other != this)
  {
@@ -68,9 +68,9 @@ operator=(const NearestInterpolationType& other)
  }
 
  template<typename MatType>
- NearestInterpolationType<MatType>&
- NearestInterpolationType<MatType>::
- operator=(NearestInterpolationType&& other)
+ NearestInterpolation<MatType>&
+ NearestInterpolation<MatType>::
+ operator=(NearestInterpolation&& other)
  {
  if (&other != this)
  {
@@ -81,8 +81,8 @@ operator=(NearestInterpolationType&& other)
  }
 
  template<typename MatType>
- void NearestInterpolationType<MatType>::Forward(
- const MatType& input, MatType& output)
+ void NearestInterpolation<MatType>::Forward(
+ const MatType& input, MatType& output)
  {
  const size_t channels = this->inputDimensions[2];
 
@@ -100,7 +100,7 @@ void NearestInterpolationType<MatType>::Forward(
 
  for (size_t i = 0; i < outRowSize; ++i)
  {
- size_t rOrigin = std::floor(i / scaleFactors[0]);
+ size_t rOrigin = std::floor(i / scaleFactors[0]);
  for (size_t j = 0; j < outColSize; ++j)
  {
  size_t cOrigin = std::floor(j / scaleFactors[1]);
@@ -113,10 +113,11 @@ void NearestInterpolationType<MatType>::Forward(
  }
 
  template<typename MatType>
- void NearestInterpolationType<MatType>::Backward(
- const MatType& /*input*/,
- const MatType& gradient,
- MatType& output)
+ void NearestInterpolation<MatType>::Backward(
+ const MatType& /* input */,
+ const MatType& /* output */,
+ const MatType& gy,
+ MatType& g)
  {
  const size_t channels = this->inputDimensions[2];
 
@@ -126,12 +127,11 @@ void NearestInterpolationType<MatType>::Backward(
  const size_t inRowSize = this->inputDimensions[0];
  const size_t inColSize = this->inputDimensions[1];
 
- CubeType outputAsCube;
- CubeType gradientAsCube;
+ CubeType gTemp;
+ CubeType gyTemp;
 
- MakeAlias(outputAsCube, output, inRowSize, inColSize, channels, 0, true);
- MakeAlias(gradientAsCube, gradient, outRowSize, outColSize, channels, 0,
- false);
+ MakeAlias(gTemp, g, inRowSize, inColSize, channels, 0);
+ MakeAlias(gyTemp, gy, outRowSize, outColSize, channels, 0);
 
  for (size_t i = 0; i < outRowSize; ++i)
  {
@@ -140,15 +140,13 @@ void NearestInterpolationType<MatType>::Backward(
  {
  size_t cOrigin = std::floor(j / scaleFactors[1]);
  for (size_t k = 0; k < channels; ++k)
- {
- outputAsCube(rOrigin, cOrigin, k) += gradientAsCube(i, j, k);
- }
+ gTemp(rOrigin, cOrigin, k) += gyTemp(i, j, k);
  }
  }
  }
 
  template<typename MatType>
- void NearestInterpolationType<MatType>::ComputeOutputDimensions()
+ void NearestInterpolation<MatType>::ComputeOutputDimensions()
  {
  if (this->inputDimensions.size() < scaleFactors.size())
  {
@@ -168,9 +166,10 @@ void NearestInterpolationType<MatType>::ComputeOutputDimensions()
 
  template<typename MatType>
  template<typename Archive>
- void NearestInterpolationType<MatType>::serialize(
- Archive& ar, const uint32_t /* version */)
+ void NearestInterpolation<MatType>::serialize(
+ Archive& ar, const uint32_t /* version */)
  {
+ ar(cereal::base_class<Layer<MatType>>(this));
  ar(CEREAL_NVP(scaleFactors));
  }