mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415)
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -54,9 +54,10 @@ class ElishFunction
54
54
  * @param x Input data.
55
55
  * @return f(x).
56
56
  */
57
- static double Fn(const double x)
57
+ template<typename ElemType>
58
+ static ElemType Fn(const ElemType x)
58
59
  {
59
- if (x < 0.0)
60
+ if (x < 0)
60
61
  return (std::exp(x) - 1) / (1 + std::exp(-x));
61
62
 
62
63
  return x / (1 + std::exp(-x));
@@ -71,8 +72,8 @@ class ElishFunction
71
72
  template<typename InputVecType, typename OutputVecType>
72
73
  static void Fn(const InputVecType& x, OutputVecType& y)
73
74
  {
74
- y = ((x < 0.0) % ((exp(x) - 1) / (1 + exp(-x))))
75
- + ((x >= 0.0) % (x / (1 + exp(-x))));
75
+ y = (conv_to<InputVecType>::from(x < 0) % ((exp(x) - 1) / (1 + exp(-x))))
76
+ + (conv_to<InputVecType>::from(x >= 0) % (x / (1 + exp(-x))));
76
77
  }
77
78
 
78
79
  /**
@@ -82,17 +83,19 @@ class ElishFunction
82
83
  * @param y Result of Fn(x).
83
84
  * @return f'(x).
84
85
  */
85
- static double Deriv(const double x, const double y)
86
+ template<typename ElemType>
87
+ static ElemType Deriv(const ElemType x, const ElemType y)
86
88
  {
87
- if (x < 0.0)
89
+ if (x < 0)
88
90
  {
89
91
  return std::exp(x) - 2 / (1 + std::exp(x)) +
90
92
  2 / std::pow(1 + std::exp(x) , 2);
91
93
  }
92
94
  else if (x == 0)
93
95
  {
94
- return 0.5; // the expression below is indeterminate at 0, even though
95
- // the expression solely in terms of x is defined (= 0.5)
96
+ // The expression below is indeterminate at 0, even though the expression
97
+ // solely in terms of x is defined (= 0.5).
98
+ return ElemType(0.5);
96
99
  }
97
100
  else
98
101
  {
@@ -118,12 +121,14 @@ class ElishFunction
118
121
  // the expression solely in terms of x is defined (= 0.5)
119
122
  // only calculate exp(x) once for each element where x < 0
120
123
  // this gives approx 3x speedup, despite allocating the temp vector
121
- DerivVecType ex = (x < 0) % exp(x);
122
- dy = ((x < 0) % ((ex - 2 / (1 + ex) + 2 / pow(1 + ex, 2)))) +
123
- ((x > 0) % ((y / x) % (1.0 + x - y)));
124
+ DerivVecType ex = conv_to<DerivVecType>::from(x < 0) % exp(x);
125
+ dy = (conv_to<InputVecType>::from(x < 0) %
126
+ ((ex - 2 / (1 + ex) + 2 / square(1 + ex)))) +
127
+ (conv_to<InputVecType>::from(x > 0) %
128
+ ((y / x) % (1 + x - y)));
124
129
  // need to do this here, because the /x above gives nans even when the
125
130
  // condition is not met (e.g. when x > 0 is false)
126
- dy(arma::find(x == 0)).fill(0.5);
131
+ dy(arma::find(x == 0)).fill(typename InputVecType::elem_type(0.5));
127
132
  }
128
133
  }; // class ElishFunction
129
134
 
@@ -45,9 +45,10 @@ class ElliotFunction
45
45
  * @param x Input data.
46
46
  * @return f(x).
47
47
  */
48
- static double Fn(const double x)
48
+ template<typename ElemType>
49
+ static ElemType Fn(const ElemType x)
49
50
  {
50
- return x / (1.0 + std::abs(x));
51
+ return x / (1 + std::abs(x));
51
52
  }
52
53
 
53
54
  /**
@@ -56,10 +57,10 @@ class ElliotFunction
56
57
  * @param x Input data.
57
58
  * @param y The resulting output activation.
58
59
  */
59
- template <typename InputVecType, typename OutputVecType>
60
+ template<typename InputVecType, typename OutputVecType>
60
61
  static void Fn(const InputVecType &x, OutputVecType &y)
61
62
  {
62
- y = x / (1.0 + arma::abs(x));
63
+ y = x / (1 + arma::abs(x));
63
64
  }
64
65
 
65
66
  /**
@@ -69,9 +70,10 @@ class ElliotFunction
69
70
  * @param y Result of Fn(x).
70
71
  * @return f'(x).
71
72
  */
72
- static double Deriv(const double x, const double /* y */)
73
+ template<typename ElemType>
74
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
73
75
  {
74
- return 1.0 / std::pow(1.0 + std::abs(x), 2);
76
+ return 1 / std::pow(1 + std::abs(x), ElemType(2));
75
77
  }
76
78
 
77
79
  /**
@@ -86,7 +88,7 @@ class ElliotFunction
86
88
  const OutputVecType& /* y */,
87
89
  DerivVecType &dy)
88
90
  {
89
- dy = 1.0 / pow(1.0 + arma::abs(x), 2);
91
+ dy = 1 / square(1 + abs(x));
90
92
  }
91
93
  }; // class ElliotFunction
92
94
 
@@ -22,7 +22,7 @@ namespace mlpack {
22
22
  *
23
23
  * @f{eqnarray*}{
24
24
  * f(x) &=& e^{-1 * x^2} \\
25
- * f'(x) &=& 2 * -x * f(x)
25
+ * f'(x) &=& 2 * -x * f(x)
26
26
  * @f}
27
27
  */
28
28
  class GaussianFunction
@@ -34,10 +34,10 @@ class GaussianFunction
34
34
  * @param x Input data.
35
35
  * @return f(x).
36
36
  */
37
- template<typename eT>
38
- static double Fn(const eT x)
37
+ template<typename ElemType>
38
+ static ElemType Fn(const ElemType x)
39
39
  {
40
- return std::exp(-1 * std::pow(x, 2));
40
+ return std::exp(-std::pow(x, ElemType(2)));
41
41
  }
42
42
 
43
43
  /**
@@ -49,7 +49,7 @@ class GaussianFunction
49
49
  template<typename InputVecType, typename OutputVecType>
50
50
  static void Fn(const InputVecType& x, OutputVecType& y)
51
51
  {
52
- y = exp(-1 * pow(x, 2));
52
+ y = exp(-square(x));
53
53
  }
54
54
 
55
55
  /**
@@ -59,7 +59,8 @@ class GaussianFunction
59
59
  * @param y Result of Fn(x).
60
60
  * @return f'(x)
61
61
  */
62
- static double Deriv(const double x, const double y)
62
+ template<typename ElemType>
63
+ static ElemType Deriv(const ElemType x, const ElemType y)
63
64
  {
64
65
  return -2 * x * y;
65
66
  }
@@ -0,0 +1,73 @@
1
+ /**
2
+ * @file methods/ann/activation_functions/gelu_exact_function.hpp
3
+ * @author Kumar Utkarsh
4
+ *
5
+ * Definition and implementation of the exact Gaussian Error Linear Unit (GELU)
6
+ * function.
7
+ *
8
+ * mlpack is free software; you may redistribute it and/or modify it under the
9
+ * terms of the 3-clause BSD license. You should have received a copy of the
10
+ * 3-clause BSD license along with mlpack. If not, see
11
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12
+ */
13
+ #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_EXACT_FUNCTION_HPP
14
+ #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_EXACT_FUNCTION_HPP
15
+
16
+ #include <mlpack/prereqs.hpp>
17
+
18
+ namespace mlpack {
19
+
20
+ /**
21
+ * The exact GELU function, defined by
22
+ *
23
+ * @f{eqnarray*}{
24
+ * f(x) = x * Phi(x) \\
25
+ * Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) \\
26
+ * f'(x) = Phi(x) + x * phi(x) \\
27
+ * phi(x) = (1 / sqrt(2\pi)) * exp(-x^2 / 2)
28
+ * @f}
29
+ */
30
+ class GELUExactFunction
31
+ {
32
+ public:
33
+ //! Compute the exact GELU function for a single value.
34
+ static double Fn(const double x)
35
+ {
36
+ return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
37
+ }
38
+
39
+ //! Compute the exact GELU function for matrices/vectors.
40
+ template<typename InputVecType, typename OutputVecType>
41
+ static void Fn(const InputVecType& x, OutputVecType& y)
42
+ {
43
+ y = 0.5 * x % (1.0 + erf(x / std::sqrt(2.0)));
44
+ }
45
+
46
+ // Compute the first derivative of the exact GELU function for a single value
47
+ static double Deriv(const double x, const double y )
48
+ {
49
+ const double phi = std::exp(-0.5 * x * x) / std::sqrt(2.0 * M_PI);
50
+ // Reuse y to avoid costly Phi(x) computation.
51
+ return (x == 0.0) ? 0.5 : (y / x + x * phi);
52
+ }
53
+
54
+ //! Compute the first derivative for matrices/vectors.
55
+ template<typename InputVecType, typename OutputVecType, typename DerivVecType>
56
+ static void Deriv(const InputVecType& x,
57
+ const OutputVecType& y,
58
+ DerivVecType& dy)
59
+ {
60
+ dy.set_size(x.n_elem);
61
+ // Reuse y to avoid costly Phi(x) computation.
62
+ for (size_t i = 0; i < x.n_elem; ++i)
63
+ {
64
+ if (x[i] == 0.0) dy[i] = 0.5;
65
+ else dy[i] = y[i] / x[i] +
66
+ x[i] * std::exp(-0.5 * x[i] * x[i]) / std::sqrt(2.0 * M_PI);
67
+ }
68
+ }
69
+ }; // class GELUExactFunction
70
+
71
+ } // namespace mlpack
72
+
73
+ #endif
@@ -37,10 +37,12 @@ class GELUFunction
37
37
  * @param x Input data.
38
38
  * @return f(x).
39
39
  */
40
- static double Fn(const double x)
40
+ template<typename ElemType>
41
+ static ElemType Fn(const ElemType x)
41
42
  {
42
- return 0.5 * x * (1 + std::tanh(std::sqrt(2 / M_PI) *
43
- (x + 0.044715 * std::pow(x, 3))));
43
+ return (x / 2) *
44
+ (1 + std::tanh(std::sqrt(2 / arma::Datum<ElemType>::pi) *
45
+ (x + ElemType(0.044715) * std::pow(x, ElemType(3)))));
44
46
  }
45
47
 
46
48
  /**
@@ -52,8 +54,11 @@ class GELUFunction
52
54
  template<typename InputVecType, typename OutputVecType>
53
55
  static void Fn(const InputVecType& x, OutputVecType& y)
54
56
  {
55
- y = 0.5 * x % (1 + arma::tanh(std::sqrt(2 / M_PI) *
56
- (x + 0.044715 * pow(x, 3))));
57
+ typedef typename InputVecType::elem_type ElemType;
58
+
59
+ y = (x / 2) %
60
+ (1 + tanh(std::sqrt(2 / arma::Datum<ElemType>::pi) *
61
+ (x + ElemType(0.044715) * pow(x, ElemType(3)))));
57
62
  }
58
63
 
59
64
  /**
@@ -63,13 +68,16 @@ class GELUFunction
63
68
  * @param y Result of Fn(x).
64
69
  * @return f'(x)
65
70
  */
66
- static double Deriv(const double x, const double /* y */)
71
+ template<typename ElemType>
72
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
67
73
  {
68
- if (x < -10) return 0.0; // catch overflows
69
- return 0.5 * std::tanh(0.0356774 * std::pow(x, 3) + 0.797885 * x) +
70
- (0.0535161 * std::pow(x, 3) + 0.398942 * x) *
71
- std::pow(1 / std::cosh(0.0356774 * std::pow(x, 3) +
72
- 0.797885 * x), 2) + 0.5;
74
+ if (x < -10) return 0; // catch overflows
75
+ return ElemType(0.5) * std::tanh(ElemType(0.0356774) *
76
+ std::pow(x, ElemType(3)) + ElemType(0.797885) * x) +
77
+ (ElemType(0.0535161) * std::pow(x, ElemType(3)) +
78
+ ElemType(0.398942) * x) *
79
+ std::pow(1 / std::cosh(ElemType(0.0356774) * std::pow(x, 3) +
80
+ ElemType(0.797885) * x), 2) + ElemType(0.5);
73
81
  }
74
82
 
75
83
  /**
@@ -84,11 +92,14 @@ class GELUFunction
84
92
  const OutputVecType& /* y */,
85
93
  DerivVecType& dy)
86
94
  {
87
- dy = 0.5 * arma::tanh(0.0356774 * pow(x, 3) + 0.797885 * x) +
88
- (0.0535161 * pow(x, 3) + 0.398942 * x) %
89
- pow(1 / arma::cosh(0.0356774 * pow(x, 3) +
90
- 0.797885 * x), 2) + 0.5;
91
- dy(arma::find(x < -10)).fill(0); // catch overflows
95
+ typedef typename InputVecType::elem_type ElemType;
96
+
97
+ dy = ElemType(0.5) * tanh(ElemType(0.0356774) * pow(x, ElemType(3)) +
98
+ ElemType(0.797885) * x) + (ElemType(0.0535161) * pow(x, ElemType(3)) +
99
+ ElemType(0.398942) * x) %
100
+ pow(1 / cosh(ElemType(0.0356774) * pow(x, ElemType(3)) +
101
+ ElemType(0.797885) * x), 2) + ElemType(0.5);
102
+ dy(find(x < -10)).fill(0); // catch overflows
92
103
  }
93
104
  }; // class GELUFunction
94
105
 
@@ -39,9 +39,10 @@ class HardSigmoidFunction
39
39
  * @param x Input data.
40
40
  * @return f(x).
41
41
  */
42
- static double Fn(const double x)
42
+ template<typename ElemType>
43
+ static ElemType Fn(const ElemType x)
43
44
  {
44
- return std::min(1.0, std::max(0.0, 0.2 * x + 0.5));
45
+ return std::min(ElemType(1), std::max(ElemType(0), x / 5 + ElemType(0.5)));
45
46
  }
46
47
 
47
48
  /**
@@ -67,13 +68,14 @@ class HardSigmoidFunction
67
68
  * @param y Result of Fn(x).
68
69
  * @return f'(x)
69
70
  */
70
- static double Deriv(const double /* x */, const double y)
71
+ template<typename ElemType>
72
+ static ElemType Deriv(const ElemType /* x */, const ElemType y)
71
73
  {
72
- if (y == 0.0 || y == 1.0)
74
+ if (y == 0 || y == 1)
73
75
  {
74
- return 0.0;
76
+ return 0;
75
77
  }
76
- return 0.2;
78
+ return ElemType(0.2);
77
79
  }
78
80
 
79
81
  /**
@@ -52,7 +52,8 @@ class HardSwishFunction
52
52
  * @param x Input data.
53
53
  * @return f(x).
54
54
  */
55
- static double Fn(const double x)
55
+ template<typename ElemType>
56
+ static ElemType Fn(const ElemType x)
56
57
  {
57
58
  if (x <= -3)
58
59
  return 0;
@@ -68,7 +69,7 @@ class HardSwishFunction
68
69
  * @param x Input data.
69
70
  * @param y The resulting output activation.
70
71
  */
71
- template <typename InputVecType, typename OutputVecType>
72
+ template<typename InputVecType, typename OutputVecType>
72
73
  static void Fn(const InputVecType &x, OutputVecType &y)
73
74
  {
74
75
  y.set_size(size(x));
@@ -85,14 +86,15 @@ class HardSwishFunction
85
86
  * @param * (y) Result of Fn(x).
86
87
  * @return f'(x).
87
88
  */
88
- static double Deriv(const double x, const double /* y */)
89
+ template<typename ElemType>
90
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
89
91
  {
90
92
  if (x <= -3)
91
93
  return 0;
92
94
  else if (x >= 3)
93
95
  return 1;
94
96
 
95
- return (2 * x + 3.0) / 6.0;
97
+ return (2 * x + 3) / 6;
96
98
  }
97
99
 
98
100
  /**
@@ -56,12 +56,13 @@ class HyperSinhFunction
56
56
  * @param x Input data.
57
57
  * @return f(x).
58
58
  */
59
- static double Fn(const double x)
59
+ template<typename ElemType>
60
+ static ElemType Fn(const ElemType x)
60
61
  {
61
62
  if (x > 0)
62
- return (std::sinh(x) / 3.0);
63
+ return (std::sinh(x) / 3);
63
64
  else
64
- return (std::pow(x, 3.0) / 4.0);
65
+ return (std::pow(x, ElemType(3)) / 4);
65
66
  }
66
67
 
67
68
  /**
@@ -94,12 +95,13 @@ class HyperSinhFunction
94
95
  * @param y Input activation.
95
96
  * @return f'(x)
96
97
  */
97
- static double Deriv(const double /* x */, const double y)
98
+ template<typename ElemType>
99
+ static ElemType Deriv(const ElemType /* x */, const ElemType y)
98
100
  {
99
101
  if (y > 0)
100
- return (std::pow((1.0 / 9.0) + (y * y), 0.5));
102
+ return (std::pow(ElemType(1.0 / 9.0) + (y * y), ElemType(0.5)));
101
103
  else
102
- return (3.0 * std::pow(std::pow(y, 2) / 4, 1.0 / 3.0));
104
+ return (3 * std::pow(std::pow(y, ElemType(2)) / 4, ElemType(1.0 / 3.0)));
103
105
  }
104
106
 
105
107
  /**
@@ -113,17 +115,20 @@ class HyperSinhFunction
113
115
  const OutputVecType& y,
114
116
  DerivVecType& dy)
115
117
  {
118
+ typedef typename InputVecType::elem_type ElemType;
119
+
116
120
  dy.set_size(size(y));
117
121
  #pragma omp for
118
122
  for (size_t i = 0; i < y.n_elem; ++i)
119
123
  {
120
124
  if (y(i) > 0)
121
125
  {
122
- dy(i) = (std::pow((1.0 / 9.0) + (y(i) * y(i)), 0.5));
126
+ dy(i) = (std::pow(ElemType(1.0 / 9.0) + (y(i) * y(i)), ElemType(0.5)));
123
127
  }
124
128
  else
125
129
  {
126
- dy(i) = (3.0 * std::pow(std::pow(y(i), 2) / 4, 1.0 / 3.0));
130
+ dy(i) = (3 * std::pow(std::pow(y(i), ElemType(2)) / 4,
131
+ ElemType(1.0 / 3.0)));
127
132
  }
128
133
  }
129
134
  }
@@ -33,7 +33,8 @@ class IdentityFunction
33
33
  * @param x Input data.
34
34
  * @return f(x).
35
35
  */
36
- static double Fn(const double x)
36
+ template<typename ElemType>
37
+ static ElemType Fn(const ElemType x)
37
38
  {
38
39
  return x;
39
40
  }
@@ -59,9 +60,10 @@ class IdentityFunction
59
60
  * @param * (y) Result of Fn(x).
60
61
  * @return f'(x)
61
62
  */
62
- static double Deriv(const double /* x */, const double /* y */)
63
+ template<typename ElemType>
64
+ static ElemType Deriv(const ElemType /* x */, const ElemType /* y */)
63
65
  {
64
- return 1.0;
66
+ return 1;
65
67
  }
66
68
 
67
69
  /**
@@ -76,7 +78,7 @@ class IdentityFunction
76
78
  const OutputVecType& /* y */,
77
79
  DerivVecType& dy)
78
80
  {
79
- dy.ones(arma::size(x));
81
+ dy.ones(size(x));
80
82
  }
81
83
 
82
84
  /**
@@ -33,9 +33,10 @@ class InvQuadFunction
33
33
  * @param x Input data.
34
34
  * @return f(x).
35
35
  */
36
- static double Fn(const double x)
36
+ template<typename ElemType>
37
+ static ElemType Fn(const ElemType x)
37
38
  {
38
- return 1 / ( 1 + x * x);
39
+ return 1 / (1 + x * x);
39
40
  }
40
41
 
41
42
  /**
@@ -47,7 +48,7 @@ class InvQuadFunction
47
48
  template<typename InputVecType, typename OutputVecType>
48
49
  static void Fn(const InputVecType& x, OutputVecType& y)
49
50
  {
50
- y = 1 / (1 + pow(x, 2));
51
+ y = 1 / (1 + square(x));
51
52
  }
52
53
 
53
54
  /**
@@ -57,9 +58,10 @@ class InvQuadFunction
57
58
  * @param y Result of Fn(x).
58
59
  * @return f'(x)
59
60
  */
60
- static double Deriv(const double x, const double /* y */)
61
+ template<typename ElemType>
62
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
61
63
  {
62
- return -2 * x / std::pow(1 + std::pow(x, 2), 2);
64
+ return -2 * x / std::pow(1 + std::pow(x, ElemType(2)), ElemType(2));
63
65
  }
64
66
 
65
67
  /**
@@ -74,7 +76,7 @@ class InvQuadFunction
74
76
  const OutputVecType& /* y */,
75
77
  DerivVecType &dy)
76
78
  {
77
- dy = - 2 * x / pow(1 + pow(x, 2), 2);
79
+ dy = -2 * x / square(1 + square(x));
78
80
  }
79
81
  }; // class InvQuadFunction
80
82
 
@@ -47,7 +47,8 @@ class LiSHTFunction
47
47
  * @param x Input data.
48
48
  * @return f(x).
49
49
  */
50
- static double Fn(const double x)
50
+ template<typename ElemType>
51
+ static ElemType Fn(const ElemType x)
51
52
  {
52
53
  return x * std::tanh(x);
53
54
  }
@@ -61,7 +62,7 @@ class LiSHTFunction
61
62
  template <typename InputVecType, typename OutputVecType>
62
63
  static void Fn(const InputVecType &x, OutputVecType &y)
63
64
  {
64
- y = x % arma::tanh(x);
65
+ y = x % tanh(x);
65
66
  }
66
67
 
67
68
  /**
@@ -71,9 +72,10 @@ class LiSHTFunction
71
72
  * @param y Result of Fn(x).
72
73
  * @return f'(x)
73
74
  */
74
- static double Deriv(const double x, const double /* y */)
75
+ template<typename ElemType>
76
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
75
77
  {
76
- return std::tanh(x) + x * (1 - std::pow(std::tanh(x), 2));
78
+ return std::tanh(x) + x * (1 - std::pow(std::tanh(x), ElemType(2)));
77
79
  }
78
80
 
79
81
  /**
@@ -88,7 +90,7 @@ class LiSHTFunction
88
90
  const OutputVecType& /* y */,
89
91
  DerivVecType& dy)
90
92
  {
91
- dy = arma::tanh(x) + x % (1 - pow(arma::tanh(x), 2));
93
+ dy = tanh(x) + x % (1 - square(tanh(x)));
92
94
  }
93
95
  }; // class LishtFunction
94
96
 
@@ -34,18 +34,18 @@ class LogisticFunction
34
34
  * @param x Input data.
35
35
  * @return f(x).
36
36
  */
37
- template<typename eT>
38
- static double Fn(const eT x)
37
+ template<typename ElemType>
38
+ static ElemType Fn(const ElemType x)
39
39
  {
40
- if (x < arma::Datum<eT>::log_max)
40
+ if (x < arma::Datum<ElemType>::log_max)
41
41
  {
42
- if (x > -arma::Datum<eT>::log_max)
43
- return 1.0 / (1.0 + std::exp(-x));
42
+ if (x > -arma::Datum<ElemType>::log_max)
43
+ return 1 / (1 + std::exp(-x));
44
44
 
45
- return 0.0;
45
+ return 0;
46
46
  }
47
47
 
48
- return 1.0;
48
+ return 1;
49
49
  }
50
50
 
51
51
  /**
@@ -57,7 +57,7 @@ class LogisticFunction
57
57
  template<typename InputVecType, typename OutputVecType>
58
58
  static void Fn(const InputVecType& x, OutputVecType& y)
59
59
  {
60
- y = (1.0 / (1 + exp(-x)));
60
+ y = (1 / (1 + exp(-x)));
61
61
  }
62
62
 
63
63
  /**
@@ -67,9 +67,10 @@ class LogisticFunction
67
67
  * @param y Result of Fn(x).
68
68
  * @return f'(x)
69
69
  */
70
- static double Deriv(const double /* x */, const double y)
70
+ template<typename ElemType>
71
+ static ElemType Deriv(const ElemType /* x */, const ElemType y)
71
72
  {
72
- return y * (1.0 - y);
73
+ return y * (1 - y);
73
74
  }
74
75
 
75
76
  /**
@@ -84,7 +85,7 @@ class LogisticFunction
84
85
  const OutputVecType& y,
85
86
  DerivVecType& dy)
86
87
  {
87
- dy = y % (1.0 - y);
88
+ dy = y % (1 - y);
88
89
  }
89
90
 
90
91
  /**
@@ -93,7 +94,8 @@ class LogisticFunction
93
94
  * @param y Input data.
94
95
  * @return f^{-1}(y)
95
96
  */
96
- static double Inv(const double y)
97
+ template<typename ElemType>
98
+ static ElemType Inv(const ElemType y)
97
99
  {
98
100
  return arma::trunc_log(y / (1 - y));
99
101
  }
@@ -45,7 +45,8 @@ class MishFunction
45
45
  * @param x Input data.
46
46
  * @return f(x).
47
47
  */
48
- static double Fn(const double x)
48
+ template<typename ElemType>
49
+ static ElemType Fn(const ElemType x)
49
50
  {
50
51
  return x * (std::exp(2 * x) + 2 * std::exp(x)) /
51
52
  (2 + 2 * std::exp(x) + std::exp(2 * x));
@@ -57,7 +58,7 @@ class MishFunction
57
58
  * @param x Input data.
58
59
  * @param y The resulting output activation.
59
60
  */
60
- template <typename InputVecType, typename OutputVecType>
61
+ template<typename InputVecType, typename OutputVecType>
61
62
  static void Fn(const InputVecType &x, OutputVecType &y)
62
63
  {
63
64
  y = x % (exp(2 * x) + 2 * exp(x)) / (2 + 2 * exp(x) + exp(2 * x));
@@ -70,11 +71,12 @@ class MishFunction
70
71
  * @param y Result of Fn(x).
71
72
  * @return f'(x)
72
73
  */
73
- static double Deriv(const double x, const double /* y */)
74
+ template<typename ElemType>
75
+ static ElemType Deriv(const ElemType x, const ElemType /* y */)
74
76
  {
75
77
  return std::exp(x) * (4 * (x + 1) + std::exp(x) * (4 * x + 6) +
76
78
  4 * std::exp(2 * x) + std::exp(3 * x)) /
77
- std::pow(std::exp(2 * x) + 2 * std::exp(x) + 2, 2);
79
+ std::pow(std::exp(2 * x) + 2 * std::exp(x) + 2, ElemType(2));
78
80
  }
79
81
 
80
82
  /**
@@ -91,7 +93,7 @@ class MishFunction
91
93
  {
92
94
  dy = exp(x) % (4 * (x + 1) + exp(x) % (4 * x + 6) +
93
95
  4 * exp(2 * x) + exp(3 * x)) /
94
- pow(exp(2 * x) + 2 * exp(x) + 2, 2);
96
+ square(exp(2 * x) + 2 * exp(x) + 2);
95
97
  }
96
98
  }; // class MishFunction
97
99