mlpack 4.6.2__cp39-cp39-win_amd64.whl → 4.7.0__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (417) hide show
  1. mlpack/__init__.py +6 -6
  2. mlpack/adaboost_classify.cp39-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp39-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp39-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp39-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp39-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp39-win_amd64.pyd +0 -0
  8. mlpack/cf.cp39-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp39-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp39-win_amd64.pyd +0 -0
  11. mlpack/det.cp39-win_amd64.pyd +0 -0
  12. mlpack/emst.cp39-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp39-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp39-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp39-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp39-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp39-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp39-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp39-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp39-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp39-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp39-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp39-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp39-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp39-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp39-win_amd64.pyd +0 -0
  365. mlpack/knn.cp39-win_amd64.pyd +0 -0
  366. mlpack/krann.cp39-win_amd64.pyd +0 -0
  367. mlpack/lars.cp39-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp39-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp39-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp39-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp39-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp39-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp39-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp39-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp39-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp39-win_amd64.pyd +0 -0
  377. mlpack/nca.cp39-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp39-win_amd64.pyd +0 -0
  379. mlpack/pca.cp39-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp39-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp39-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp39-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp39-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp39-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp39-win_amd64.pyd +0 -0
  386. mlpack/radical.cp39-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp39-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp39-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp39-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +397 -378
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack.libs/.load-order-mlpack-4.7.0 +2 -0
  395. mlpack/include/mlpack/core/data/format.hpp +0 -31
  396. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  397. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  398. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  399. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  400. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  401. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  402. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  403. mlpack/include/mlpack/core/data/types.hpp +0 -61
  404. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  405. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  406. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  407. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  408. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  412. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  413. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  414. mlpack.libs/.load-order-mlpack-4.6.2 +0 -2
  415. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  416. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  417. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -0,0 +1,197 @@
1
+ /**
2
+ * @file methods/ann/convolution_rules/base_convolution.hpp
3
+ * @author Zachary Ng
4
+ *
5
+ * Base class for convolution rules.
6
+ *
7
+ * mlpack is free software; you may redistribute it and/or modify it under the
8
+ * terms of the 3-clause BSD license. You should have received a copy of the
9
+ * 3-clause BSD license along with mlpack. If not, see
10
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11
+ */
12
+ #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_BASE_CONVOLUTION_HPP
13
+ #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_BASE_CONVOLUTION_HPP
14
+
15
+ #include <mlpack/prereqs.hpp>
16
+ #include <mlpack/core/util/using.hpp>
17
+ #include "border_modes.hpp"
18
+
19
+ namespace mlpack {
20
+
21
+ /**
22
+ * This is an abstract class that contains common functions for convolution.
23
+ * This class allows specification of the border type. The
24
+ * convolution can be computed with the valid border type or the full border
25
+ * type (default).
26
+ *
27
+ * FullConvolution: returns the full two-dimensional convolution.
28
+ * ValidConvolution: returns only those parts of the convolution that are
29
+ * computed without the zero-padded edges.
30
+ *
31
+ * @tparam BorderMode Type of the border mode (FullConvolution or
32
+ * ValidConvolution).
33
+ */
34
+ template<typename BorderMode = FullConvolution>
35
+ class BaseConvolution
36
+ {
37
+ protected:
38
+ /**
39
+ * Apply padding to an input matrix.
40
+ *
41
+ * @param input Input used to perform the convolution.
42
+ * @param filter Filter used to perform the convolution.
43
+ * @param inputPadded Input with padding applied.
44
+ */
45
+ template<typename InMatType, typename FilType, typename Border = BorderMode>
46
+ static void
47
+ PadInput(const InMatType& input,
48
+ const FilType& filter,
49
+ InMatType& inputPadded,
50
+ const size_t dilationW,
51
+ const size_t dilationH,
52
+ const typename std::enable_if_t<IsMatrix<InMatType>::value>* = 0)
53
+ {
54
+ if constexpr (std::is_same_v<Border, ValidConvolution>)
55
+ {
56
+ // Use valid padding (none).
57
+ MakeAlias(inputPadded, input, input.n_rows, input.n_cols);
58
+ }
59
+ else
60
+ {
61
+ // Use full padding
62
+
63
+ // First, compute the necessary padding for the full convolution. It is
64
+ // possible that this might be an overestimate. Note that these variables
65
+ // only hold the padding on one side of the input.
66
+ const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
67
+ const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
68
+ const size_t paddingRows = filterRows - 1;
69
+ const size_t paddingCols = filterCols - 1;
70
+
71
+ // Allocate the padded input and copy the original input into its center.
72
+ inputPadded = InMatType(input.n_rows + 2 * paddingRows,
73
+ input.n_cols + 2 * paddingCols);
74
+ inputPadded.submat(paddingRows, paddingCols,
75
+ paddingRows + input.n_rows - 1,
76
+ paddingCols + input.n_cols - 1) = input;
77
+ }
78
+ }
79
+
80
+ /**
81
+ * Apply padding to an input cube.
82
+ *
83
+ * @param input Input used to perform the convolution.
84
+ * @param filter Filter used to perform the convolution.
85
+ * @param inputPadded Input with padding applied.
86
+ */
87
+ template<typename InCubeType, typename FilType, typename Border = BorderMode>
88
+ static void
89
+ PadInput(const InCubeType& input,
90
+ const FilType& filter,
91
+ InCubeType& inputPadded,
92
+ const size_t dilationW,
93
+ const size_t dilationH,
94
+ const typename std::enable_if_t<IsCube<InCubeType>::value>* = 0)
95
+ {
96
+ if constexpr (std::is_same_v<Border, ValidConvolution>)
97
+ {
98
+ // Use valid padding (none).
99
+ MakeAlias(inputPadded, input, input.n_rows, input.n_cols, input.n_slices);
100
+ }
101
+ else
102
+ {
103
+ // Use full padding
104
+
105
+ // First, compute the necessary padding for the full convolution. It is
106
+ // possible that this might be an overestimate. Note that these variables
107
+ // only hold the padding on one side of the input.
108
+ const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
109
+ const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
110
+ const size_t paddingRows = filterRows - 1;
111
+ const size_t paddingCols = filterCols - 1;
112
+
113
+ // Allocate the padded input and copy the original input into its center.
114
+ inputPadded = InCubeType(input.n_rows + 2 * paddingRows,
115
+ input.n_cols + 2 * paddingCols, input.n_slices);
116
+ inputPadded.subcube(paddingRows, paddingCols, 0,
117
+ paddingRows + input.n_rows - 1, paddingCols + input.n_cols - 1,
118
+ input.n_slices - 1) = input;
119
+ }
120
+ }
121
+
122
+ /**
123
+ * Initialize the output to the required size.
124
+ *
125
+ * @param inputPadded Input with padding applied.
126
+ * @param filter Filter used to perform the convolution.
127
+ * @param output Output data that contains the results of the convolution.
128
+ * @param dW Stride of filter application in the x direction.
129
+ * @param dH Stride of filter application in the y direction.
130
+ * @param dilationW The dilation factor in x direction.
131
+ * @param dilationH The dilation factor in y direction.
132
+ * @param outSlices The number of slices in the output cube.
133
+ */
134
+ template<typename InMatType, typename FilType, typename OutMatType>
135
+ static void
136
+ InitalizeOutput(const InMatType& inputPadded,
137
+ const FilType& filter,
138
+ OutMatType& output,
139
+ const size_t dW = 1,
140
+ const size_t dH = 1,
141
+ const size_t dilationW = 1,
142
+ const size_t dilationH = 1,
143
+ const size_t /* outSlices */ = 1,
144
+ const typename std::enable_if_t<
145
+ IsMatrix<OutMatType>::value>* = 0)
146
+ {
147
+ // Compute the output size. The filterRows and filterCols computation must
148
+ // take into account the fact that dilation only adds rows or columns
149
+ // *between* filter elements. So, e.g., a dilation of 2 on a kernel size of
150
+ // 3x3 means an effective kernel size of 5x5, *not* 6x6.
151
+ const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
152
+ const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
153
+ const size_t outputRows = (inputPadded.n_rows - filterRows + dW) / dW;
154
+ const size_t outputCols = (inputPadded.n_cols - filterCols + dH) / dH;
155
+ output.zeros(outputRows, outputCols);
156
+ }
157
+
158
+ /**
159
+ * Initialize the output to the required size.
160
+ *
161
+ * @param inputPadded Input with padding applied.
162
+ * @param filter Filter used to perform the convolution.
163
+ * @param output Output data that contains the results of the convolution.
164
+ * @param dW Stride of filter application in the x direction.
165
+ * @param dH Stride of filter application in the y direction.
166
+ * @param dilationW The dilation factor in x direction.
167
+ * @param dilationH The dilation factor in y direction.
168
+ * @param outSlices The number of slices in the output cube.
169
+ */
170
+ template<typename InMatType, typename FilType, typename OutCubeType>
171
+ static void
172
+ InitalizeOutput(const InMatType& inputPadded,
173
+ const FilType& filter,
174
+ OutCubeType& output,
175
+ const size_t dW = 1,
176
+ const size_t dH = 1,
177
+ const size_t dilationW = 1,
178
+ const size_t dilationH = 1,
179
+ const size_t outSlices = 1,
180
+ const typename std::enable_if_t<
181
+ IsCube<OutCubeType>::value>* = 0)
182
+ {
183
+ // Compute the output size. The filterRows and filterCols computation must
184
+ // take into account the fact that dilation only adds rows or columns
185
+ // *between* filter elements. So, e.g., a dilation of 2 on a kernel size of
186
+ // 3x3 means an effective kernel size of 5x5, *not* 6x6.
187
+ const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
188
+ const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
189
+ const size_t outputRows = (inputPadded.n_rows - filterRows + dW) / dW;
190
+ const size_t outputCols = (inputPadded.n_cols - filterCols + dH) / dH;
191
+ output.zeros(outputRows, outputCols, outSlices);
192
+ }
193
+ }; // class BaseConvolution
194
+
195
+ } // namespace mlpack
196
+
197
+ #endif
@@ -14,8 +14,7 @@
14
14
  #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_CONVOLUTION_RULES_HPP
15
15
 
16
16
  #include "border_modes.hpp"
17
- #include "fft_convolution.hpp"
18
17
  #include "naive_convolution.hpp"
19
- #include "svd_convolution.hpp"
18
+ #include "im2col_convolution.hpp"
20
19
 
21
20
  #endif
@@ -0,0 +1,215 @@
1
+ /**
2
+ * @file methods/ann/convolution_rules/im2col_convolution.hpp
3
+ * @author Zachary Ng
4
+ *
5
+ * Implementation of the im2col convolution. This is actually im2row because we
6
+ * use column major order.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_IM2COL_CONVOLUTION_HPP
14
+ #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_IM2COL_CONVOLUTION_HPP
15
+
16
+ #include "base_convolution.hpp"
17
+
18
+ namespace mlpack {
19
+
20
+ /**
21
+ * Computes the two-dimensional convolution. This class allows specification of
22
+ * the border type. The convolution can be computed with the valid
23
+ * border type or the full border type (default).
24
+ *
25
+ * FullConvolution: returns the full two-dimensional convolution.
26
+ * ValidConvolution: returns only those parts of the convolution that are
27
+ * computed without the zero-padded edges.
28
+ *
29
+ * @tparam BorderMode Type of the border mode (FullConvolution or
30
+ * ValidConvolution).
31
+ */
32
+ template<typename BorderMode = FullConvolution>
33
+ class Im2ColConvolution : public BaseConvolution<BorderMode>
34
+ {
35
+ public:
36
+ /**
37
+ * Perform a convolution using 3rd order tensors. Expects that `filter` has
38
+ * `input.n_slices * output.n_slices` slices. The Nth `input.n_slices` filters
39
+ * are applied to all input slices and output to the Nth output slice.
40
+ * e.g. with 2 input slices: filter 0 applies to input 0, output 0,
41
+ * fil 1 * in 1 = out 0, fil 2 * in 0 = out 1, fil 3 * in 1 = out 1,
42
+ * fil 4 * in 0 = out 2, fil 5 * in 1 = out 2, etc.
43
+ *
44
+ * @param input Input used to perform the convolution.
45
+ * @param filter Filter used to perform the convolution.
46
+ * @param output Output data that contains the results of the convolution.
47
+ * @param dW Stride of filter application in the x direction.
48
+ * @param dH Stride of filter application in the y direction.
49
+ * @param dilationW The dilation factor in x direction.
50
+ * @param dilationH The dilation factor in y direction.
51
+ * @param appending If true, it will not initialize the output. Instead,
52
+ * it will accumulate (add) the results into the existing output.
53
+ */
54
+ template<typename CubeType>
55
+ static void Convolution(
56
+ const CubeType& input,
57
+ const CubeType& filter,
58
+ CubeType& output,
59
+ const size_t dW = 1,
60
+ const size_t dH = 1,
61
+ const size_t dilationW = 1,
62
+ const size_t dilationH = 1,
63
+ const bool appending = false,
64
+ const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
65
+ {
66
+ using MatType = typename GetDenseMatType<CubeType>::type;
67
+
68
+ CubeType inputPadded;
69
+ Im2ColConvolution::PadInput(input, filter, inputPadded, dilationW,
70
+ dilationH);
71
+
72
+ const size_t inMaps = input.n_slices;
73
+ const size_t outMaps = filter.n_slices / inMaps;
74
+
75
+ if (!appending)
76
+ Im2ColConvolution::InitalizeOutput(inputPadded, filter, output, dW, dH,
77
+ dilationW, dilationH, outMaps);
78
+
79
+ // `im2row` is held transposed.
80
+ MatType im2row(filter.n_rows * filter.n_cols * input.n_slices,
81
+ output.n_rows * output.n_cols, GetFillType<MatType>::none);
82
+ // Arrange im2row so that each row has patches from each input map.
83
+ for (size_t i = 0; i < input.n_slices; ++i)
84
+ {
85
+ Im2Row(inputPadded.slice(i), im2row, filter.n_rows, filter.n_cols,
86
+ filter.n_rows * filter.n_cols * i, dW, dH, dilationW, dilationH);
87
+ }
88
+
89
+ // The filters already have the correct order in memory, just reshape it.
90
+ MatType fil2col;
91
+ MakeAlias(fil2col, filter, filter.n_rows * filter.n_cols * inMaps,
92
+ outMaps);
93
+
94
+ // The output is also already in the correct order.
95
+ MatType tempOutput;
96
+ MakeAlias(tempOutput, output, output.n_rows * output.n_cols, outMaps);
97
+
98
+ tempOutput += trans(im2row) * fil2col;
99
+ }
100
+
101
+ /**
102
+ * Perform a convolution using a dense matrix as input and 3rd order tensors
103
+ * as filter and output.
104
+ *
105
+ * @param input Input used to perform the convolution.
106
+ * @param filter Filter used to perform the convolution.
107
+ * @param output Output data that contains the results of the convolution.
108
+ * @param dW Stride of filter application in the x direction.
109
+ * @param dH Stride of filter application in the y direction.
110
+ * @param dilationW The dilation factor in x direction.
111
+ * @param dilationH The dilation factor in y direction.
112
+ * @param appending If true, it will not initialize the output. Instead,
113
+ * it will accumulate (add) the results into the existing output.
114
+ */
115
+ template<typename MatType, typename CubeType>
116
+ static void Convolution(
117
+ const MatType& input,
118
+ const CubeType& filter,
119
+ CubeType& output,
120
+ const size_t dW = 1,
121
+ const size_t dH = 1,
122
+ const size_t dilationW = 1,
123
+ const size_t dilationH = 1,
124
+ const bool appending = false,
125
+ const typename std::enable_if_t<IsMatrix<MatType>::value>* = 0,
126
+ const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
127
+ {
128
+ MatType inputPadded;
129
+ Im2ColConvolution::PadInput(input, filter, inputPadded, dilationW,
130
+ dilationH);
131
+
132
+ if (!appending)
133
+ Im2ColConvolution::InitalizeOutput(inputPadded, filter, output, dW, dH,
134
+ dilationW, dilationH, filter.n_slices);
135
+
136
+ // `im2row` is held transposed.
137
+ MatType im2row(filter.n_rows * filter.n_cols, output.n_rows * output.n_cols,
138
+ GetFillType<MatType>::none);
139
+ Im2Row(inputPadded, im2row, filter.n_rows, filter.n_cols, 0, dW, dH,
140
+ dilationW, dilationH);
141
+
142
+ // The filters already have the correct order in memory, just reshape it.
143
+ MatType fil2col;
144
+ MakeAlias(fil2col, filter, filter.n_rows * filter.n_cols, filter.n_slices);
145
+
146
+ // The output is also already in the correct order.
147
+ MatType tempOutput;
148
+ MakeAlias(tempOutput, output, output.n_rows * output.n_cols,
149
+ filter.n_slices);
150
+
151
+ tempOutput += trans(im2row) * fil2col;
152
+ }
153
+ private:
154
+ /**
155
+ * Take an input and convert each patch into columns (held transposed).
156
+ * This function expects that `im2row` has the expected dimensions.
157
+ *
158
+ * @param input Input used to perform the convolution.
159
+ * @param im2row Patches of the input as rows.
160
+ * @param filterRows Number of rows in a filter.
161
+ * @param filterCols Number of columns in a filter.
162
+ * @param startRow Row offset in `im2row` where this input's patches start.
163
+ * @param dW Stride of filter application in the x direction.
164
+ * @param dH Stride of filter application in the y direction.
165
+ * @param dilationW The dilation factor in x direction.
166
+ * @param dilationH The dilation factor in y direction.
167
+ */
168
+ template<typename MatType>
169
+ static void Im2Row(const MatType& input,
170
+ MatType& im2row,
171
+ const size_t filterRows,
172
+ const size_t filterCols,
173
+ const size_t startRow = 0,
174
+ const size_t dW = 1,
175
+ const size_t dH = 1,
176
+ const size_t dilationW = 1,
177
+ const size_t dilationH = 1)
178
+ {
179
+ using UVecType = typename GetURowType<MatType>::type;
180
+
181
+ const size_t dFilterRows = filterRows * dilationW - (dilationW - 1);
182
+ const size_t dFilterCols = filterCols * dilationH - (dilationH - 1);
183
+ const size_t outputRows = (input.n_rows - dFilterRows + dW) / dW;
184
+ const size_t outputCols = (input.n_cols - dFilterCols + dH) / dH;
185
+ const bool useDilation = (dilationW != 1) || (dilationH != 1);
186
+
187
+ size_t outCol = 0;
188
+ const size_t filterElems = filterRows * filterCols;
189
+ MatType colAlias;
190
+ for (size_t j = 0; j < outputCols; j++)
191
+ {
192
+ size_t inCol = j * dH;
193
+ for (size_t i = 0; i < outputRows; i++)
194
+ {
195
+ size_t inRow = i * dW;
196
+ // Use an alias instead of `.col()` to avoid the creation of a
197
+ // temporary subview object.
198
+ MakeAlias(colAlias, im2row, filterElems, 1, outCol * im2row.n_rows +
199
+ startRow);
200
+ if (useDilation)
201
+ colAlias = vectorise(input.submat(linspace<UVecType>(inRow,
202
+ inRow + dFilterRows - 1, filterRows),
203
+ linspace<UVecType>(inCol, inCol + dFilterCols - 1, filterCols)));
204
+ else
205
+ colAlias = vectorise(input.submat(inRow, inCol,
206
+ inRow + filterRows - 1, inCol + filterCols - 1));
207
+ outCol++;
208
+ }
209
+ }
210
+ }
211
+ }; // class Im2ColConvolution
212
+
213
+ } // namespace mlpack
214
+
215
+ #endif