mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -34,9 +34,10 @@ class MultiQuadFunction
34
34
  * @param x Input data.
35
35
  * @return f(x).
36
36
  */
37
- static double Fn(const double x)
37
+ template<typename ElemType>
38
+ static ElemType Fn(const ElemType x)
38
39
  {
39
- return std::pow(1 + x * x, 0.5);
40
+ return std::sqrt(1 + x * x);
40
41
  }
41
42
 
42
43
  /**
@@ -48,7 +49,7 @@ class MultiQuadFunction
48
49
  template<typename InputVecType, typename OutputVecType>
49
50
  static void Fn(const InputVecType& x, OutputVecType& y)
50
51
  {
51
- y = pow((1 + pow(x, 2)), 0.5);
52
+ y = sqrt((1 + square(x)));
52
53
  }
53
54
 
54
55
  /**
@@ -61,7 +62,8 @@ class MultiQuadFunction
61
62
  * @param y Result of Fn(x).
62
63
  * @return f'(x)
63
64
  */
64
- static double Deriv(const double x, const double y)
65
+ template<typename ElemType>
66
+ static ElemType Deriv(const ElemType x, const ElemType y)
65
67
  {
66
68
  return x / y;
67
69
  }
@@ -33,7 +33,8 @@ class Poisson1Function
33
33
  * @param x Input data.
34
34
  * @return f(x).
35
35
  */
36
- static double Fn(const double x)
36
+ template<typename ElemType>
37
+ static ElemType Fn(const ElemType x)
37
38
  {
38
39
  return (x - 1) * std::exp(-x);
39
40
  }
@@ -57,7 +58,8 @@ class Poisson1Function
57
58
  * @param y Result of Fn(x).
58
59
  * @return f'(x)
59
60
  */
60
- static double Deriv(const double x, const double /* y */)
61
+ template<typename ElemType>
62
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
61
63
  {
62
64
  return -std::exp(-x) * (x - 2);
63
65
  }
@@ -33,9 +33,10 @@ class QuadraticFunction
33
33
  * @param x Input data.
34
34
  * @return f(x).
35
35
  */
36
- static double Fn(const double x)
36
+ template<typename ElemType>
37
+ static ElemType Fn(const ElemType x)
37
38
  {
38
- return std::pow(x, 2);
39
+ return std::pow(x, ElemType(2));
39
40
  }
40
41
 
41
42
  /**
@@ -47,7 +48,7 @@ class QuadraticFunction
47
48
  template<typename InputVecType, typename OutputVecType>
48
49
  static void Fn(const InputVecType& x, OutputVecType& y)
49
50
  {
50
- y = pow(x, 2);
51
+ y = square(x);
51
52
  }
52
53
 
53
54
  /**
@@ -57,7 +58,8 @@ class QuadraticFunction
57
58
  * @param y Result of Fn(x).
58
59
  * @return f'(x)
59
60
  */
60
- static double Deriv(const double x, const double /* y */)
61
+ template<typename ElemType>
62
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
61
63
  {
62
64
  return 2 * x;
63
65
  }
@@ -50,9 +50,10 @@ class RectifierFunction
50
50
  * @param x Input data.
51
51
  * @return f(x).
52
52
  */
53
- static double Fn(const double x)
53
+ template<typename ElemType>
54
+ static ElemType Fn(const ElemType x)
54
55
  {
55
- return std::max(0.0, x);
56
+ return std::max(ElemType(0), x);
56
57
  }
57
58
 
58
59
  /**
@@ -64,9 +65,7 @@ class RectifierFunction
64
65
  template<typename MatType>
65
66
  static void Fn(const MatType& x, MatType& y)
66
67
  {
67
- y.set_size(size(x));
68
- y.zeros();
69
- y = max(y, x);
68
+ y = clamp(x, 0, std::numeric_limits<typename MatType::elem_type>::max());
70
69
  }
71
70
 
72
71
  /**
@@ -76,9 +75,10 @@ class RectifierFunction
76
75
  * @param y Result of Fn(x).
77
76
  * @return f'(x)
78
77
  */
79
- static double Deriv(const double x, const double /* y */)
78
+ template<typename ElemType>
79
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
80
80
  {
81
- return (double)(x > 0);
81
+ return ElemType(x > 0);
82
82
  }
83
83
 
84
84
  /**
@@ -89,11 +89,11 @@ class RectifierFunction
89
89
  * @param dy The resulting derivatives.
90
90
  */
91
91
  template<typename InputType, typename OutputType, typename DerivType>
92
- static void Deriv(const InputType& x,
93
- const OutputType& /* y */,
92
+ static void Deriv(const InputType& /* x */,
93
+ const OutputType& y,
94
94
  DerivType& dy)
95
95
  {
96
- dy = ConvTo<DerivType>::From(x > 0);
96
+ dy = sign(y);
97
97
  }
98
98
  }; // class RectifierFunction
99
99
 
@@ -49,9 +49,10 @@ class SILUFunction
49
49
  * @param x Input data.
50
50
  * @return f(x).
51
51
  */
52
- static double Fn(const double x)
52
+ template<typename ElemType>
53
+ static ElemType Fn(const ElemType x)
53
54
  {
54
- return x / (1.0 + std::exp(-x));
55
+ return x / (1 + std::exp(-x));
55
56
  }
56
57
 
57
58
  /**
@@ -63,7 +64,7 @@ class SILUFunction
63
64
  template<typename InputVecType, typename OutputVecType>
64
65
  static void Fn(const InputVecType &x, OutputVecType &y)
65
66
  {
66
- y = x / (1.0 + exp(-x));
67
+ y = x / (1 + exp(-x));
67
68
  }
68
69
 
69
70
  /**
@@ -73,11 +74,12 @@ class SILUFunction
73
74
  * @param y Result of Fn(x).
74
75
  * @return f'(x)
75
76
  */
76
- static double Deriv(const double x, const double y)
77
+ template<typename ElemType>
78
+ static double Deriv(const ElemType x, const ElemType y)
77
79
  {
78
80
  // since y = x * sigmoid(x)
79
- double sigmoid = y / x; // save an exp
80
- return x == 0 ? 0.5 : sigmoid * (1.0 + x * (1.0 - sigmoid));
81
+ const ElemType sigmoid = y / x; // save an exp
82
+ return x == 0 ? ElemType(0.5) : sigmoid * (1 + x * (1 - sigmoid));
81
83
  // the expression above is indeterminate at 0, even though
82
84
  // the expression solely in terms of x is defined (= 0.5)
83
85
  }
@@ -97,10 +99,10 @@ class SILUFunction
97
99
  // since y = x * sigmoid(x)
98
100
  // DerivVecType sigmoid = y / x;
99
101
  // dy = sigmoid % (1.0 + x % (1.0 - sigmoid));
100
- dy = (y / x) % (1.0 + x - y);
102
+ dy = (y / x) % (1 + x - y);
101
103
  // the expression above is indeterminate at 0, even though
102
104
  // the expression solely in terms of x is defined (= 0.5)
103
- dy(arma::find(x == 0)).fill(0.5);
105
+ dy(arma::find(x == 0)).fill(typename InputVecType::elem_type(0.5));
104
106
  }
105
107
  }; // class SILUFunction
106
108
 
@@ -48,9 +48,10 @@ class SoftplusFunction
48
48
  * @param x Input data.
49
49
  * @return f(x).
50
50
  */
51
- static double Fn(const double x)
51
+ template<typename ElemType>
52
+ static ElemType Fn(const ElemType x)
52
53
  {
53
- const double val = std::log(1 + std::exp(x));
54
+ const ElemType val = std::log(1 + std::exp(x));
54
55
  if (std::isfinite(val))
55
56
  return val;
56
57
  return x;
@@ -65,7 +66,7 @@ class SoftplusFunction
65
66
  template<typename InputType, typename OutputType>
66
67
  static void Fn(const InputType& x, OutputType& y)
67
68
  {
68
- y.set_size(arma::size(x));
69
+ y.set_size(size(x));
69
70
 
70
71
  for (size_t i = 0; i < x.n_elem; ++i)
71
72
  y(i) = Fn(x(i));
@@ -78,9 +79,10 @@ class SoftplusFunction
78
79
  * @param y Result of Fn(x).
79
80
  * @return f'(x)
80
81
  */
81
- static double Deriv(const double x, const double /* y */)
82
+ template<typename ElemType>
83
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
82
84
  {
83
- return 1.0 / (1 + std::exp(-x));
85
+ return 1 / (1 + std::exp(-x));
84
86
  }
85
87
 
86
88
  /**
@@ -95,7 +97,7 @@ class SoftplusFunction
95
97
  const OutputType& /* y */,
96
98
  DerivType& dy)
97
99
  {
98
- dy = 1.0 / (1 + exp(-x));
100
+ dy = 1 / (1 + exp(-x));
99
101
  }
100
102
 
101
103
  /**
@@ -104,9 +106,10 @@ class SoftplusFunction
104
106
  * @param y Input data.
105
107
  * @return f^{-1}(y)
106
108
  */
107
- static double Inv(const double y)
109
+ template<typename ElemType>
110
+ static ElemType Inv(const ElemType y)
108
111
  {
109
- const double val = std::log(std::exp(y) - 1);
112
+ const ElemType val = std::log(std::exp(y) - 1);
110
113
  if (std::isfinite(val))
111
114
  return val;
112
115
  return y;
@@ -121,7 +124,7 @@ class SoftplusFunction
121
124
  template<typename InputType, typename OutputType>
122
125
  static void Inv(const InputType& y, OutputType& x)
123
126
  {
124
- x.set_size(arma::size(y));
127
+ x.set_size(size(y));
125
128
 
126
129
  for (size_t i = 0; i < y.n_elem; ++i)
127
130
  x(i) = Inv(y(i));
@@ -9,7 +9,7 @@
9
9
  *
10
10
  * @code
11
11
  * @inproceedings{GlorotAISTATS2010,
12
- * title={title={Understanding the difficulty of training deep feedforward
12
+ * title={Understanding the difficulty of training deep feedforward
13
13
  * neural networks},
14
14
  * author={Glorot, Xavier and Bengio, Yoshua},
15
15
  * booktitle={Proceedings of AISTATS 2010},
@@ -34,13 +34,7 @@ namespace mlpack {
34
34
  *
35
35
  * @f{eqnarray*}{
36
36
  * f(x) &=& \frac{x}{1 + |x|} \\
37
- * f'(x) &=& (1 - |f(x)|)^2 \\
38
- * f(x) &=& \left\{
39
- * \begin{array}{lr}
40
- * -\frac{x}{1 - x} & : x \le 0 \\
41
- * \frac{x}{1 + x} & : x > 0
42
- * \end{array}
43
- * \right.
37
+ * f'(x) &=& (1 + |f(x)|)^2 \\
44
38
  * @f}
45
39
  */
46
40
  class SoftsignFunction
@@ -52,11 +46,10 @@ class SoftsignFunction
52
46
  * @param x Input data.
53
47
  * @return f(x).
54
48
  */
55
- static double Fn(const double x)
49
+ template<typename ElemType>
50
+ static ElemType Fn(const ElemType x)
56
51
  {
57
- if (x < DBL_MAX)
58
- return x > -DBL_MAX ? x / (1.0 + std::abs(x)) : -1.0;
59
- return 1.0;
52
+ return x / (1 + std::abs(x));
60
53
  }
61
54
 
62
55
  /**
@@ -68,10 +61,7 @@ class SoftsignFunction
68
61
  template<typename InputVecType, typename OutputVecType>
69
62
  static void Fn(const InputVecType& x, OutputVecType& y)
70
63
  {
71
- y.set_size(arma::size(x));
72
-
73
- for (size_t i = 0; i < x.n_elem; ++i)
74
- y(i) = Fn(x(i));
64
+ y = x / (1 + abs(x));
75
65
  }
76
66
 
77
67
  /**
@@ -81,9 +71,10 @@ class SoftsignFunction
81
71
  * @param y Result of Fn(x).
82
72
  * @return f'(x)
83
73
  */
84
- static double Deriv(const double x, const double /* y */)
74
+ template<typename ElemType>
75
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
85
76
  {
86
- return 1.0 / std::pow(1.0 + std::abs(x), 2);
77
+ return 1 / std::pow(1 + std::abs(x), ElemType(2));
87
78
  }
88
79
 
89
80
  /**
@@ -98,7 +89,7 @@ class SoftsignFunction
98
89
  const OutputVecType& /* y */,
99
90
  DerivVecType& dy)
100
91
  {
101
- dy = 1.0 / pow(1.0 + arma::abs(x), 2);
92
+ dy = 1 / square(1 + abs(x));
102
93
  }
103
94
 
104
95
  /**
@@ -107,12 +98,13 @@ class SoftsignFunction
107
98
  * @param y Input data.
108
99
  * @return f^{-1}(y)
109
100
  */
110
- static double Inv(const double y)
101
+ template<typename ElemType>
102
+ static ElemType Inv(const ElemType y)
111
103
  {
112
104
  if (y > 0)
113
- return y < 1 ? -y / (y - 1) : DBL_MAX;
105
+ return -y / (y - 1);
114
106
  else
115
- return y > -1 ? y / (1 + y) : -DBL_MAX;
107
+ return y / (1 + y);
116
108
  }
117
109
 
118
110
  /**
@@ -124,7 +116,7 @@ class SoftsignFunction
124
116
  template<typename InputVecType, typename OutputVecType>
125
117
  static void Inv(const InputVecType& y, OutputVecType& x)
126
118
  {
127
- x.set_size(arma::size(y));
119
+ x.set_size(size(y));
128
120
 
129
121
  for (size_t i = 0; i < y.n_elem; ++i)
130
122
  x(i) = Inv(y(i));
@@ -34,9 +34,10 @@ class SplineFunction
34
34
  * @param x Input data.
35
35
  * @return f(x).
36
36
  */
37
- static double Fn(const double x)
37
+ template<typename ElemType>
38
+ static ElemType Fn(const ElemType x)
38
39
  {
39
- return std::pow(x, 2) * std::log(1 + x);
40
+ return std::pow(x, ElemType(2)) * std::log(1 + x);
40
41
  }
41
42
 
42
43
  /**
@@ -48,7 +49,7 @@ class SplineFunction
48
49
  template<typename InputVecType, typename OutputVecType>
49
50
  static void Fn(const InputVecType& x, OutputVecType& y)
50
51
  {
51
- y = pow(x, 2) % log(1 + x);
52
+ y = square(x) % log(1 + x);
52
53
  }
53
54
 
54
55
  /**
@@ -58,9 +59,10 @@ class SplineFunction
58
59
  * @param y Result of Fn(x).
59
60
  * @return f'(x)
60
61
  */
61
- static double Deriv(const double x, const double y)
62
+ template<typename ElemType>
63
+ static ElemType Deriv(const ElemType x, const ElemType y)
62
64
  {
63
- return x != 0 ? 2 * y / x + std::pow(x, 2) / (1 + x) : 0;
65
+ return (x != 0) ? (2 * y / x + std::pow(x, ElemType(2)) / (1 + x)) : 0;
64
66
  }
65
67
 
66
68
  /**
@@ -75,10 +77,10 @@ class SplineFunction
75
77
  const OutputVecType& y,
76
78
  DerivVecType& dy)
77
79
  {
78
- dy = 2 * y / x + pow(x, 2) / (1 + x);
80
+ dy = 2 * y / x + square(x) / (1 + x);
79
81
  // the expression above is indeterminate at 0, even though
80
82
  // the expression solely in terms of x is defined (= 0)
81
- dy(arma::find(x == 0)).zeros();
83
+ dy(find(x == 0)).zeros();
82
84
  }
83
85
  }; // class SplineFunction
84
86
 
@@ -36,9 +36,10 @@ class SwishFunction
36
36
  * @param x Input data.
37
37
  * @return f(x).
38
38
  */
39
- static double Fn(const double x)
39
+ template<typename ElemType>
40
+ static ElemType Fn(const ElemType x)
40
41
  {
41
- return x / (1.0 + std::exp(-x));
42
+ return x / (1 + std::exp(-x));
42
43
  }
43
44
 
44
45
  /**
@@ -51,7 +52,7 @@ class SwishFunction
51
52
  static void Fn(const MatType& x, MatType& y,
52
53
  const typename std::enable_if_t<IsMatrix<MatType>::value>* = 0)
53
54
  {
54
- y = x / (1.0 + exp(-x));
55
+ y = x / (1 + exp(-x));
55
56
  }
56
57
 
57
58
  /**
@@ -64,7 +65,7 @@ class SwishFunction
64
65
  static void Fn(const VecType& x, VecType& y,
65
66
  const typename std::enable_if_t<IsVector<VecType>::value>* = 0)
66
67
  {
67
- y.set_size(arma::size(x));
68
+ y.set_size(size(x));
68
69
 
69
70
  for (size_t i = 0; i < x.n_elem; ++i)
70
71
  y(i) = Fn(x(i));
@@ -77,10 +78,11 @@ class SwishFunction
77
78
  * @param y Result of Fn(x).
78
79
  * @return f'(x)
79
80
  */
80
- static double Deriv(const double x, const double y)
81
+ template<typename ElemType>
82
+ static ElemType Deriv(const ElemType x, const ElemType y)
81
83
  {
82
- double sigmoid = y / x; // save an exp
83
- return x == 0 ? 0.5 : sigmoid * (1.0 + x * (1.0 - sigmoid));
84
+ const ElemType sigmoid = y / x; // save an exp
85
+ return (x == 0) ? ElemType(0.5) : sigmoid * (1 + x * (1 - sigmoid));
84
86
  // the expression above is indeterminate at 0, even though
85
87
  // the expression solely in terms of x is defined (= 0.5)
86
88
  }
@@ -97,10 +99,10 @@ class SwishFunction
97
99
  const OutputVecType& y,
98
100
  DerivVecType& dy)
99
101
  {
100
- dy = (y / x) % (1.0 + x - y);
102
+ dy = (y / x) % (1 + x - y);
101
103
  // the expression above is indeterminate at 0, even though
102
104
  // the expression solely in terms of x is defined (= 0.5)
103
- dy(arma::find(x == 0)).fill(0.5);
105
+ dy(find(x == 0)).fill(typename InputVecType::elem_type(0.5));
104
106
  }
105
107
  }; // class SwishFunction
106
108
 
@@ -48,7 +48,8 @@ class TanhExpFunction
48
48
  * @param x Input data.
49
49
  * @return f(x).
50
50
  */
51
- static double Fn(const double x)
51
+ template<typename ElemType>
52
+ static ElemType Fn(const ElemType x)
52
53
  {
53
54
  return x * std::tanh(std::exp(x));
54
55
  }
@@ -62,7 +63,7 @@ class TanhExpFunction
62
63
  template<typename InputVecType, typename OutputVecType>
63
64
  static void Fn(const InputVecType& x, OutputVecType& y)
64
65
  {
65
- y = x % arma::tanh(exp(x));
66
+ y = x % tanh(exp(x));
66
67
  }
67
68
 
68
69
  /**
@@ -72,11 +73,12 @@ class TanhExpFunction
72
73
  * @param y Result of Fn(x).
73
74
  * @return f'(x)
74
75
  */
75
- static double Deriv(const double x, const double y)
76
+ template<typename ElemType>
77
+ static ElemType Deriv(const ElemType x, const ElemType y)
76
78
  {
77
79
  // leverage both y and x
78
- return x == 0 ? std::tanh(1) :
79
- y / x + x * std::exp(x) * (1 - std::pow(y / x, 2));
80
+ return (x == 0) ? std::tanh(1) :
81
+ y / x + x * std::exp(x) * (1 - std::pow(y / x, ElemType(2)));
80
82
  }
81
83
 
82
84
  /**
@@ -92,10 +94,10 @@ class TanhExpFunction
92
94
  DerivVecType& dy)
93
95
  {
94
96
  // leverage both y and x
95
- dy = y / x + x % exp(x) % (1 - pow(y / x, 2));
97
+ dy = y / x + x % exp(x) % (1 - square(y / x));
96
98
  // the expression above is indeterminate at 0, even though
97
99
  // the expression solely in terms of x is defined (= tanh(1))
98
- dy(arma::find(x == 0)).fill(std::tanh(1));
100
+ dy(find(x == 0)).fill(std::tanh(typename InputVecType::elem_type(1)));
99
101
  }
100
102
  }; // class TanhExpFunction
101
103
 
@@ -34,7 +34,8 @@ class TanhFunction
34
34
  * @param x Input data.
35
35
  * @return f(x).
36
36
  */
37
- static double Fn(const double x)
37
+ template<typename ElemType>
38
+ static ElemType Fn(const ElemType x)
38
39
  {
39
40
  return std::tanh(x);
40
41
  }
@@ -48,7 +49,7 @@ class TanhFunction
48
49
  template<typename InputVecType, typename OutputVecType>
49
50
  static void Fn(const InputVecType& x, OutputVecType& y)
50
51
  {
51
- y = arma::tanh(x);
52
+ y = tanh(x);
52
53
  }
53
54
 
54
55
  /**
@@ -58,9 +59,10 @@ class TanhFunction
58
59
  * @param y Result of Fn(x).
59
60
  * @return f'(x)
60
61
  */
61
- static double Deriv(const double /* x */, const double y)
62
+ template<typename ElemType>
63
+ static ElemType Deriv(const ElemType /* x */, const ElemType y)
62
64
  {
63
- return 1 - std::pow(y, 2);
65
+ return 1 - std::pow(y, ElemType(2));
64
66
  }
65
67
 
66
68
  /**
@@ -75,7 +77,7 @@ class TanhFunction
75
77
  const OutputVecType& y,
76
78
  DerivVecType& dy)
77
79
  {
78
- dy = 1 - pow(y, 2);
80
+ dy = 1 - square(y);
79
81
  }
80
82
 
81
83
  /**
@@ -84,7 +86,8 @@ class TanhFunction
84
86
  * @param y Input data.
85
87
  * @return f^{-1}(x)
86
88
  */
87
- static double Inv(const double y)
89
+ template<typename ElemType>
90
+ static ElemType Inv(const ElemType y)
88
91
  {
89
92
  return std::atanh(y);
90
93
  }
@@ -98,7 +101,7 @@ class TanhFunction
98
101
  template<typename InputVecType, typename OutputVecType>
99
102
  static void Inv(const InputVecType& y, OutputVecType& x)
100
103
  {
101
- x = arma::atanh(y);
104
+ x = atanh(y);
102
105
  }
103
106
  }; // class TanhFunction
104
107
 
@@ -28,6 +28,9 @@
28
28
  #include "regularizer/regularizer.hpp"
29
29
 
30
30
  #include "ffn.hpp"
31
+ #include "dag_network.hpp"
31
32
  #include "rnn.hpp"
32
33
 
34
+ #include "models/models.hpp"
35
+
33
36
  #endif