mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415):
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -38,11 +38,14 @@ template<
38
38
  typename MatType = arma::mat,
39
39
  typename RegularizerType = NoRegularizer
40
40
  >
41
- class LinearType : public Layer<MatType>
41
+ class Linear : public Layer<MatType>
42
42
  {
43
43
  public:
44
- //! Create the Linear object.
45
- LinearType();
44
+ // Convenience typedef to access the element type of the weights and data.
45
+ using ElemType = typename MatType::elem_type;
46
+
47
+ // Create the Linear object.
48
+ Linear();
46
49
 
47
50
  /**
48
51
  * Create the Linear layer object with the specified number of output
@@ -52,25 +55,25 @@ class LinearType : public Layer<MatType>
52
55
  * @param regularizer The regularizer to use, optional (default: no
53
56
  * regularizer).
54
57
  */
55
- LinearType(const size_t outSize,
56
- RegularizerType regularizer = RegularizerType());
58
+ Linear(const size_t outSize,
59
+ RegularizerType regularizer = RegularizerType());
57
60
 
58
- virtual ~LinearType() { }
61
+ virtual ~Linear() { }
59
62
 
60
- //! Clone the LinearType object. This handles polymorphism correctly.
61
- LinearType* Clone() const { return new LinearType(*this); }
63
+ //! Clone the Linear object. This handles polymorphism correctly.
64
+ Linear* Clone() const { return new Linear(*this); }
62
65
 
63
66
  //! Copy the other Linear layer (but not weights).
64
- LinearType(const LinearType& layer);
67
+ Linear(const Linear& layer);
65
68
 
66
69
  //! Take ownership of the members of the other Linear layer (but not weights).
67
- LinearType(LinearType&& layer);
70
+ Linear(Linear&& layer);
68
71
 
69
72
  //! Copy the other Linear layer (but not weights).
70
- LinearType& operator=(const LinearType& layer);
73
+ Linear& operator=(const Linear& layer);
71
74
 
72
75
  //! Take ownership of the members of the other Linear layer (but not weights).
73
- LinearType& operator=(LinearType&& layer);
76
+ Linear& operator=(Linear&& layer);
74
77
 
75
78
  /**
76
79
  * Reset the layer parameter (weights and bias). The method is called to
@@ -162,12 +165,7 @@ class LinearType : public Layer<MatType>
162
165
 
163
166
  //! Locally-stored regularizer object.
164
167
  RegularizerType regularizer;
165
- }; // class LinearType
166
-
167
- // Convenience typedefs.
168
-
169
- // Standard Linear layer using no regularization.
170
- using Linear = LinearType<arma::mat, NoRegularizer>;
168
+ }; // class Linear
171
169
 
172
170
  } // namespace mlpack
173
171
 
@@ -34,12 +34,15 @@ template<
34
34
  typename MatType = arma::mat,
35
35
  typename RegularizerType = NoRegularizer
36
36
  >
37
- class Linear3DType : public Layer<MatType>
37
+ class Linear3D : public Layer<MatType>
38
38
  {
39
39
  public:
40
+ // Convenience typedefs.
41
+ using ElemType = typename MatType::elem_type;
40
42
  using CubeType = typename GetCubeType<MatType>::type;
41
- //! Create the Linear3D object.
42
- Linear3DType();
43
+
44
+ // Create the Linear3D object.
45
+ Linear3D();
43
46
 
44
47
  /**
45
48
  * Create the Linear3D layer object using the specified number of output
@@ -48,23 +51,23 @@ class Linear3DType : public Layer<MatType>
48
51
  * @param outSize The number of output units.
49
52
  * @param regularizer The regularizer to use, optional.
50
53
  */
51
- Linear3DType(const size_t outSize,
54
+ Linear3D(const size_t outSize,
52
55
  RegularizerType regularizer = RegularizerType());
53
56
 
54
- //! Clone the Linear3DType object. This handles polymorphism correctly.
55
- Linear3DType* Clone() const { return new Linear3DType(*this); }
57
+ //! Clone the Linear3D object. This handles polymorphism correctly.
58
+ Linear3D* Clone() const { return new Linear3D(*this); }
56
59
 
57
60
  // Virtual destructor.
58
- virtual ~Linear3DType() { }
61
+ virtual ~Linear3D() { }
59
62
 
60
- //! Copy the given Linear3DType (but not weights).
61
- Linear3DType(const Linear3DType& other);
62
- //! Take ownership of the given Linear3DType (but not weights).
63
- Linear3DType(Linear3DType&& other);
64
- //! Copy the given Linear3DType (but not weights).
65
- Linear3DType& operator=(const Linear3DType& other);
66
- //! Take ownership of the given Linear3DType (but not weights).
67
- Linear3DType& operator=(Linear3DType&& other);
63
+ //! Copy the given Linear3D (but not weights).
64
+ Linear3D(const Linear3D& other);
65
+ //! Take ownership of the given Linear3D (but not weights).
66
+ Linear3D(Linear3D&& other);
67
+ //! Copy the given Linear3D (but not weights).
68
+ Linear3D& operator=(const Linear3D& other);
69
+ //! Take ownership of the given Linear3D (but not weights).
70
+ Linear3D& operator=(Linear3D&& other);
68
71
 
69
72
  /*
70
73
  * Reset the layer parameter.
@@ -150,9 +153,6 @@ class Linear3DType : public Layer<MatType>
150
153
  RegularizerType regularizer;
151
154
  }; // class Linear
152
155
 
153
- // Standard Linear3D layer.
154
- using Linear3D = Linear3DType<arma::mat, NoRegularizer>;
155
-
156
156
  } // namespace mlpack
157
157
 
158
158
  // Include implementation.
@@ -18,7 +18,7 @@
18
18
  namespace mlpack {
19
19
 
20
20
  template<typename MatType, typename RegularizerType>
21
- Linear3DType<MatType, RegularizerType>::Linear3DType() :
21
+ Linear3D<MatType, RegularizerType>::Linear3D() :
22
22
  Layer<MatType>(),
23
23
  outSize(0)
24
24
  {
@@ -26,7 +26,7 @@ Linear3DType<MatType, RegularizerType>::Linear3DType() :
26
26
  }
27
27
 
28
28
  template<typename MatType, typename RegularizerType>
29
- Linear3DType<MatType, RegularizerType>::Linear3DType(
29
+ Linear3D<MatType, RegularizerType>::Linear3D(
30
30
  const size_t outSize,
31
31
  RegularizerType regularizer) :
32
32
  Layer<MatType>(),
@@ -35,8 +35,8 @@ Linear3DType<MatType, RegularizerType>::Linear3DType(
35
35
  { }
36
36
 
37
37
  template<typename MatType, typename RegularizerType>
38
- Linear3DType<MatType, RegularizerType>::Linear3DType(
39
- const Linear3DType& other) :
38
+ Linear3D<MatType, RegularizerType>::Linear3D(
39
+ const Linear3D& other) :
40
40
  Layer<MatType>(other),
41
41
  outSize(other.outSize),
42
42
  regularizer(other.regularizer)
@@ -45,8 +45,8 @@ Linear3DType<MatType, RegularizerType>::Linear3DType(
45
45
  }
46
46
 
47
47
  template<typename MatType, typename RegularizerType>
48
- Linear3DType<MatType, RegularizerType>::Linear3DType(
49
- Linear3DType&& other) :
48
+ Linear3D<MatType, RegularizerType>::Linear3D(
49
+ Linear3D&& other) :
50
50
  Layer<MatType>(std::move(other)),
51
51
  outSize(std::move(other.outSize)),
52
52
  regularizer(std::move(other.regularizer))
@@ -55,9 +55,9 @@ Linear3DType<MatType, RegularizerType>::Linear3DType(
55
55
  }
56
56
 
57
57
  template<typename MatType, typename RegularizerType>
58
- Linear3DType<MatType, RegularizerType>&
59
- Linear3DType<MatType, RegularizerType>::operator=(
60
- const Linear3DType& other)
58
+ Linear3D<MatType, RegularizerType>&
59
+ Linear3D<MatType, RegularizerType>::operator=(
60
+ const Linear3D& other)
61
61
  {
62
62
  if (&other != this)
63
63
  {
@@ -70,9 +70,9 @@ Linear3DType<MatType, RegularizerType>::operator=(
70
70
  }
71
71
 
72
72
  template<typename MatType, typename RegularizerType>
73
- Linear3DType<MatType, RegularizerType>&
74
- Linear3DType<MatType, RegularizerType>::operator=(
75
- Linear3DType&& other)
73
+ Linear3D<MatType, RegularizerType>&
74
+ Linear3D<MatType, RegularizerType>::operator=(
75
+ Linear3D&& other)
76
76
  {
77
77
  if (&other != this)
78
78
  {
@@ -85,7 +85,7 @@ Linear3DType<MatType, RegularizerType>::operator=(
85
85
  }
86
86
 
87
87
  template<typename MatType, typename RegularizerType>
88
- void Linear3DType<MatType, RegularizerType>::SetWeights(
88
+ void Linear3D<MatType, RegularizerType>::SetWeights(
89
89
  const MatType& weightsIn)
90
90
  {
91
91
  MakeAlias(weights, weightsIn, outSize * this->inputDimensions[0] + outSize,
@@ -95,7 +95,7 @@ void Linear3DType<MatType, RegularizerType>::SetWeights(
95
95
  }
96
96
 
97
97
  template<typename MatType, typename RegularizerType>
98
- void Linear3DType<MatType, RegularizerType>::Forward(
98
+ void Linear3D<MatType, RegularizerType>::Forward(
99
99
  const MatType& input, MatType& output)
100
100
  {
101
101
  const size_t nPoints = input.n_rows / this->inputDimensions[0];
@@ -116,7 +116,7 @@ void Linear3DType<MatType, RegularizerType>::Forward(
116
116
  }
117
117
 
118
118
  template<typename MatType, typename RegularizerType>
119
- void Linear3DType<MatType, RegularizerType>::Backward(
119
+ void Linear3D<MatType, RegularizerType>::Backward(
120
120
  const MatType& /* input */,
121
121
  const MatType& /* output */,
122
122
  const MatType& gy,
@@ -144,7 +144,7 @@ void Linear3DType<MatType, RegularizerType>::Backward(
144
144
  }
145
145
 
146
146
  template<typename MatType, typename RegularizerType>
147
- void Linear3DType<MatType, RegularizerType>::Gradient(
147
+ void Linear3D<MatType, RegularizerType>::Gradient(
148
148
  const MatType& input,
149
149
  const MatType& error,
150
150
  MatType& gradient)
@@ -178,7 +178,7 @@ void Linear3DType<MatType, RegularizerType>::Gradient(
178
178
  }
179
179
 
180
180
  template<typename MatType, typename RegularizerType>
181
- void Linear3DType<
181
+ void Linear3D<
182
182
  MatType, RegularizerType
183
183
  >::ComputeOutputDimensions()
184
184
  {
@@ -191,7 +191,7 @@ void Linear3DType<
191
191
 
192
192
  template<typename MatType, typename RegularizerType>
193
193
  template<typename Archive>
194
- void Linear3DType<MatType, RegularizerType>::serialize(
194
+ void Linear3D<MatType, RegularizerType>::serialize(
195
195
  Archive& ar, const uint32_t /* version */)
196
196
  {
197
197
  ar(cereal::base_class<Layer<MatType>>(this));
@@ -19,7 +19,7 @@
19
19
  namespace mlpack {
20
20
 
21
21
  template<typename MatType, typename RegularizerType>
22
- LinearType<MatType, RegularizerType>::LinearType() :
22
+ Linear<MatType, RegularizerType>::Linear() :
23
23
  Layer<MatType>(),
24
24
  inSize(0),
25
25
  outSize(0)
@@ -28,7 +28,7 @@ LinearType<MatType, RegularizerType>::LinearType() :
28
28
  }
29
29
 
30
30
  template<typename MatType, typename RegularizerType>
31
- LinearType<MatType, RegularizerType>::LinearType(
31
+ Linear<MatType, RegularizerType>::Linear(
32
32
  const size_t outSize,
33
33
  RegularizerType regularizer) :
34
34
  Layer<MatType>(),
@@ -41,7 +41,7 @@ LinearType<MatType, RegularizerType>::LinearType(
41
41
 
42
42
  // Copy constructor.
43
43
  template<typename MatType, typename RegularizerType>
44
- LinearType<MatType, RegularizerType>::LinearType(const LinearType& layer) :
44
+ Linear<MatType, RegularizerType>::Linear(const Linear& layer) :
45
45
  Layer<MatType>(layer),
46
46
  inSize(layer.inSize),
47
47
  outSize(layer.outSize),
@@ -52,7 +52,7 @@ LinearType<MatType, RegularizerType>::LinearType(const LinearType& layer) :
52
52
 
53
53
  // Move constructor.
54
54
  template<typename MatType, typename RegularizerType>
55
- LinearType<MatType, RegularizerType>::LinearType(LinearType&& layer) :
55
+ Linear<MatType, RegularizerType>::Linear(Linear&& layer) :
56
56
  Layer<MatType>(std::move(layer)),
57
57
  inSize(std::move(layer.inSize)),
58
58
  outSize(std::move(layer.outSize)),
@@ -64,8 +64,8 @@ LinearType<MatType, RegularizerType>::LinearType(LinearType&& layer) :
64
64
  }
65
65
 
66
66
  template<typename MatType, typename RegularizerType>
67
- LinearType<MatType, RegularizerType>&
68
- LinearType<MatType, RegularizerType>::operator=(const LinearType& layer)
67
+ Linear<MatType, RegularizerType>&
68
+ Linear<MatType, RegularizerType>::operator=(const Linear& layer)
69
69
  {
70
70
  if (&layer != this)
71
71
  {
@@ -79,9 +79,9 @@ LinearType<MatType, RegularizerType>::operator=(const LinearType& layer)
79
79
  }
80
80
 
81
81
  template<typename MatType, typename RegularizerType>
82
- LinearType<MatType, RegularizerType>&
83
- LinearType<MatType, RegularizerType>::operator=(
84
- LinearType&& layer)
82
+ Linear<MatType, RegularizerType>&
83
+ Linear<MatType, RegularizerType>::operator=(
84
+ Linear&& layer)
85
85
  {
86
86
  if (&layer != this)
87
87
  {
@@ -99,7 +99,7 @@ LinearType<MatType, RegularizerType>::operator=(
99
99
  }
100
100
 
101
101
  template<typename MatType, typename RegularizerType>
102
- void LinearType<MatType, RegularizerType>::SetWeights(const MatType& weightsIn)
102
+ void Linear<MatType, RegularizerType>::SetWeights(const MatType& weightsIn)
103
103
  {
104
104
  MakeAlias(weights, weightsIn, outSize * inSize + outSize, 1);
105
105
  MakeAlias(weight, weightsIn, outSize, inSize);
@@ -107,7 +107,7 @@ void LinearType<MatType, RegularizerType>::SetWeights(const MatType& weightsIn)
107
107
  }
108
108
 
109
109
  template<typename MatType, typename RegularizerType>
110
- void LinearType<MatType, RegularizerType>::Forward(
110
+ void Linear<MatType, RegularizerType>::Forward(
111
111
  const MatType& input, MatType& output)
112
112
  {
113
113
  output = weight * input;
@@ -118,7 +118,7 @@ void LinearType<MatType, RegularizerType>::Forward(
118
118
  }
119
119
 
120
120
  template<typename MatType, typename RegularizerType>
121
- void LinearType<MatType, RegularizerType>::Backward(
121
+ void Linear<MatType, RegularizerType>::Backward(
122
122
  const MatType& /* input */,
123
123
  const MatType& /* output */,
124
124
  const MatType& gy,
@@ -128,7 +128,7 @@ void LinearType<MatType, RegularizerType>::Backward(
128
128
  }
129
129
 
130
130
  template<typename MatType, typename RegularizerType>
131
- void LinearType<MatType, RegularizerType>::Gradient(
131
+ void Linear<MatType, RegularizerType>::Gradient(
132
132
  const MatType& input,
133
133
  const MatType& error,
134
134
  MatType& gradient)
@@ -139,7 +139,7 @@ void LinearType<MatType, RegularizerType>::Gradient(
139
139
  }
140
140
 
141
141
  template<typename MatType, typename RegularizerType>
142
- void LinearType<MatType, RegularizerType>::ComputeOutputDimensions()
142
+ void Linear<MatType, RegularizerType>::ComputeOutputDimensions()
143
143
  {
144
144
  inSize = this->inputDimensions[0];
145
145
  for (size_t i = 1; i < this->inputDimensions.size(); ++i)
@@ -153,7 +153,7 @@ void LinearType<MatType, RegularizerType>::ComputeOutputDimensions()
153
153
 
154
154
  template<typename MatType, typename RegularizerType>
155
155
  template<typename Archive>
156
- void LinearType<MatType, RegularizerType>::serialize(
156
+ void Linear<MatType, RegularizerType>::serialize(
157
157
  Archive& ar, const uint32_t /* version */)
158
158
  {
159
159
  ar(cereal::base_class<Layer<MatType>>(this));
@@ -33,11 +33,14 @@ template<
33
33
  typename MatType = arma::mat,
34
34
  typename RegularizerType = NoRegularizer
35
35
  >
36
- class LinearNoBiasType : public Layer<MatType>
36
+ class LinearNoBias : public Layer<MatType>
37
37
  {
38
38
  public:
39
- //! Create the LinearNoBias object.
40
- LinearNoBiasType();
39
+ // Convenience typedef to access the element type of the weights and data.
40
+ using ElemType = typename MatType::elem_type;
41
+
42
+ // Create the LinearNoBias object.
43
+ LinearNoBias();
41
44
 
42
45
  /**
43
46
  * Create the LinearNoBias object using the specified number of units.
@@ -45,29 +48,29 @@ class LinearNoBiasType : public Layer<MatType>
45
48
  * @param outSize The number of output units.
46
49
  * @param regularizer The regularizer to use, optional.
47
50
  */
48
- LinearNoBiasType(const size_t outSize,
51
+ LinearNoBias(const size_t outSize,
49
52
  RegularizerType regularizer = RegularizerType());
50
53
 
51
- //! Clone the LinearNoBiasType object. This handles polymorphism correctly.
52
- LinearNoBiasType* Clone() const { return new LinearNoBiasType(*this); }
54
+ //! Clone the LinearNoBias object. This handles polymorphism correctly.
55
+ LinearNoBias* Clone() const { return new LinearNoBias(*this); }
53
56
 
54
57
  //! Reset the layer parameter.
55
58
  void SetWeights(const MatType& weightsIn);
56
59
 
57
60
  //! Copy constructor.
58
- LinearNoBiasType(const LinearNoBiasType& layer);
61
+ LinearNoBias(const LinearNoBias& layer);
59
62
 
60
63
  //! Move constructor.
61
- LinearNoBiasType(LinearNoBiasType&&);
64
+ LinearNoBias(LinearNoBias&&);
62
65
 
63
66
  //! Copy assignment operator.
64
- LinearNoBiasType& operator=(const LinearNoBiasType& layer);
67
+ LinearNoBias& operator=(const LinearNoBias& layer);
65
68
 
66
69
  //! Move assignment operator.
67
- LinearNoBiasType& operator=(LinearNoBiasType&& layer);
70
+ LinearNoBias& operator=(LinearNoBias&& layer);
68
71
 
69
72
  //! Virtual destructor.
70
- virtual ~LinearNoBiasType() { }
73
+ virtual ~LinearNoBias() { }
71
74
 
72
75
  /**
73
76
  * Ordinary feed forward pass of a neural network, evaluating the function
@@ -131,12 +134,7 @@ class LinearNoBiasType : public Layer<MatType>
131
134
 
132
135
  //! Locally-stored regularizer object.
133
136
  RegularizerType regularizer;
134
- }; // class LinearNoBiasType
135
-
136
- // Convenience typedefs.
137
-
138
- // Standard Linear without bias layer using no regularization.
139
- using LinearNoBias = LinearNoBiasType<arma::mat, NoRegularizer>;
137
+ }; // class LinearNoBias
140
138
 
141
139
  } // namespace mlpack
142
140
 
@@ -19,7 +19,7 @@
19
19
  namespace mlpack {
20
20
 
21
21
  template<typename MatType, typename RegularizerType>
22
- LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType() :
22
+ LinearNoBias<MatType, RegularizerType>::LinearNoBias() :
23
23
  Layer<MatType>(),
24
24
  inSize(0),
25
25
  outSize(0)
@@ -28,7 +28,7 @@ LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType() :
28
28
  }
29
29
 
30
30
  template<typename MatType, typename RegularizerType>
31
- LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType(
31
+ LinearNoBias<MatType, RegularizerType>::LinearNoBias(
32
32
  const size_t outSize,
33
33
  RegularizerType regularizer) :
34
34
  Layer<MatType>(),
@@ -40,8 +40,8 @@ LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType(
40
40
  }
41
41
 
42
42
  template<typename MatType, typename RegularizerType>
43
- LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType(
44
- const LinearNoBiasType& layer) :
43
+ LinearNoBias<MatType, RegularizerType>::LinearNoBias(
44
+ const LinearNoBias& layer) :
45
45
  Layer<MatType>(layer),
46
46
  inSize(layer.inSize),
47
47
  outSize(layer.outSize),
@@ -51,11 +51,11 @@ LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType(
51
51
  }
52
52
 
53
53
  template<typename MatType, typename RegularizerType>
54
- LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType(
55
- LinearNoBiasType&& layer) :
54
+ LinearNoBias<MatType, RegularizerType>::LinearNoBias(
55
+ LinearNoBias&& layer) :
56
56
  Layer<MatType>(std::move(layer)),
57
- inSize(0),
58
- outSize(0),
57
+ inSize(std::move(layer.inSize)),
58
+ outSize(std::move(layer.outSize)),
59
59
  regularizer(std::move(layer.regularizer))
60
60
  {
61
61
  // Reset parameters of other layer.
@@ -64,9 +64,9 @@ LinearNoBiasType<MatType, RegularizerType>::LinearNoBiasType(
64
64
  }
65
65
 
66
66
  template<typename MatType, typename RegularizerType>
67
- LinearNoBiasType<MatType, RegularizerType>&
68
- LinearNoBiasType<MatType, RegularizerType>::operator=(
69
- const LinearNoBiasType& layer)
67
+ LinearNoBias<MatType, RegularizerType>&
68
+ LinearNoBias<MatType, RegularizerType>::operator=(
69
+ const LinearNoBias& layer)
70
70
  {
71
71
  if (this != &layer)
72
72
  {
@@ -80,9 +80,9 @@ LinearNoBiasType<MatType, RegularizerType>::operator=(
80
80
  }
81
81
 
82
82
  template<typename MatType, typename RegularizerType>
83
- LinearNoBiasType<MatType, RegularizerType>&
84
- LinearNoBiasType<MatType, RegularizerType>::operator=(
85
- LinearNoBiasType&& layer)
83
+ LinearNoBias<MatType, RegularizerType>&
84
+ LinearNoBias<MatType, RegularizerType>::operator=(
85
+ LinearNoBias&& layer)
86
86
  {
87
87
  if (this != &layer)
88
88
  {
@@ -100,21 +100,21 @@ LinearNoBiasType<MatType, RegularizerType>::operator=(
100
100
  }
101
101
 
102
102
  template<typename MatType, typename RegularizerType>
103
- void LinearNoBiasType<MatType, RegularizerType>::SetWeights(
103
+ void LinearNoBias<MatType, RegularizerType>::SetWeights(
104
104
  const MatType& weights)
105
105
  {
106
106
  MakeAlias(weight, weights, outSize, inSize);
107
107
  }
108
108
 
109
109
  template<typename MatType, typename RegularizerType>
110
- void LinearNoBiasType<MatType, RegularizerType>::Forward(
110
+ void LinearNoBias<MatType, RegularizerType>::Forward(
111
111
  const MatType& input, MatType& output)
112
112
  {
113
113
  output = weight * input;
114
114
  }
115
115
 
116
116
  template<typename MatType, typename RegularizerType>
117
- void LinearNoBiasType<MatType, RegularizerType>::Backward(
117
+ void LinearNoBias<MatType, RegularizerType>::Backward(
118
118
  const MatType& /* input */,
119
119
  const MatType& /* output */,
120
120
  const MatType& gy,
@@ -124,7 +124,7 @@ void LinearNoBiasType<MatType, RegularizerType>::Backward(
124
124
  }
125
125
 
126
126
  template<typename MatType, typename RegularizerType>
127
- void LinearNoBiasType<MatType, RegularizerType>::Gradient(
127
+ void LinearNoBias<MatType, RegularizerType>::Gradient(
128
128
  const MatType& input,
129
129
  const MatType& error,
130
130
  MatType& gradient)
@@ -134,7 +134,7 @@ void LinearNoBiasType<MatType, RegularizerType>::Gradient(
134
134
  }
135
135
 
136
136
  template<typename MatType, typename RegularizerType>
137
- void LinearNoBiasType<MatType, RegularizerType>::ComputeOutputDimensions()
137
+ void LinearNoBias<MatType, RegularizerType>::ComputeOutputDimensions()
138
138
  {
139
139
  inSize = this->inputDimensions[0];
140
140
  for (size_t i = 1; i < this->inputDimensions.size(); ++i)
@@ -148,7 +148,7 @@ void LinearNoBiasType<MatType, RegularizerType>::ComputeOutputDimensions()
148
148
 
149
149
  template<typename MatType, typename RegularizerType>
150
150
  template<typename Archive>
151
- void LinearNoBiasType<MatType, RegularizerType>::serialize(
151
+ void LinearNoBias<MatType, RegularizerType>::serialize(
152
152
  Archive& ar, const uint32_t /* version */)
153
153
  {
154
154
  ar(cereal::base_class<Layer<MatType>>(this));
@@ -41,13 +41,16 @@ template<
41
41
  typename MatType = arma::mat,
42
42
  typename RegularizerType = NoRegularizer
43
43
  >
44
- class LinearRecurrentType : public RecurrentLayer<MatType>
44
+ class LinearRecurrent : public RecurrentLayer<MatType>
45
45
  {
46
46
  public:
47
+ // Convenience typedef to access the element type of the weights and data.
48
+ using ElemType = typename MatType::elem_type;
49
+
47
50
  /**
48
51
  * Create the LinearRecurrent layer.
49
52
  */
50
- LinearRecurrentType();
53
+ LinearRecurrent();
51
54
 
52
55
  /**
53
56
  * Create the LinearRecurrent layer object with the specified number of
@@ -57,29 +60,29 @@ class LinearRecurrentType : public RecurrentLayer<MatType>
57
60
  * @param regularizer The regularizer to use; optional (default: no
58
61
  * regularizer)
59
62
  */
60
- LinearRecurrentType(const size_t outSize,
63
+ LinearRecurrent(const size_t outSize,
61
64
  RegularizerType regularizer = RegularizerType());
62
65
 
63
- virtual ~LinearRecurrentType() { }
66
+ virtual ~LinearRecurrent() { }
64
67
 
65
- // Clone the LinearRecurrentType layer. This handles polymorphism correctly.
66
- LinearRecurrentType* Clone() const { return new LinearRecurrentType(*this); }
68
+ // Clone the LinearRecurrent layer. This handles polymorphism correctly.
69
+ LinearRecurrent* Clone() const { return new LinearRecurrent(*this); }
67
70
 
68
71
  // Copy the other linear recurrent layer, including hidden recurrent state
69
72
  // (but not weights).
70
- LinearRecurrentType(const LinearRecurrentType& layer);
73
+ LinearRecurrent(const LinearRecurrent& layer);
71
74
 
72
75
  // Take ownership of the members of the other linear recurrent layer,
73
76
  // including hidden recurrent state (but not weights).
74
- LinearRecurrentType(LinearRecurrentType&& layer);
77
+ LinearRecurrent(LinearRecurrent&& layer);
75
78
 
76
79
  // Copy the other linear recurrent layer, including hidden recurrent state
77
80
  // (but not weights).
78
- LinearRecurrentType& operator=(const LinearRecurrentType& layer);
81
+ LinearRecurrent& operator=(const LinearRecurrent& layer);
79
82
 
80
83
  // Take ownership of the members of the other linear recurrent layer,
81
84
  // including hidden recurrent state (but not weights).
82
- LinearRecurrentType& operator=(LinearRecurrentType&& layer);
85
+ LinearRecurrent& operator=(LinearRecurrent&& layer);
83
86
 
84
87
  /**
85
88
  * Set the parameters of the layer (weights, hidden state weights, and bias).
@@ -140,6 +143,12 @@ class LinearRecurrentType : public RecurrentLayer<MatType>
140
143
  // has been set.
141
144
  void ComputeOutputDimensions();
142
145
 
146
+ // Update the internal aliases of the layer when the step changes.
147
+ void OnStepChanged(const size_t step,
148
+ const size_t batchSize,
149
+ const size_t activeBatchSize,
150
+ const bool backwards);
151
+
143
152
  // Serialize the layer.
144
153
  template<typename Archive>
145
154
  void serialize(Archive& ar, const uint32_t /* version */);
@@ -160,14 +169,16 @@ class LinearRecurrentType : public RecurrentLayer<MatType>
160
169
  // Bias vector.
161
170
  MatType bias;
162
171
 
172
+ // Aliases of the recurrent states.
173
+ MatType currentOutput;
174
+ MatType previousOutput;
175
+ MatType currentGradient;
176
+ MatType previousGradient;
177
+
163
178
  // Locally-stored regularizer object.
164
179
  RegularizerType regularizer;
165
180
  };
166
181
 
167
- // Convenience typedefs.
168
-
169
- using LinearRecurrent = LinearRecurrentType<arma::mat, NoRegularizer>;
170
-
171
182
  } // namespace mlpack
172
183
 
173
184
  #include "linear_recurrent_impl.hpp"