mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415)
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -34,28 +34,28 @@ DDPG<
34
34
  NoiseType,
35
35
  UpdaterType,
36
36
  ReplayType
37
- >::DDPG(TrainingConfig& config,
38
- QNetworkType& learningQNetwork,
39
- PolicyNetworkType& policyNetwork,
40
- NoiseType& noise,
41
- ReplayType& replayMethod,
42
- UpdaterType qNetworkUpdater,
43
- UpdaterType policyNetworkUpdater,
44
- EnvironmentType environment):
45
- config(config),
46
- learningQNetwork(learningQNetwork),
47
- policyNetwork(policyNetwork),
48
- noise(noise),
49
- replayMethod(replayMethod),
50
- qNetworkUpdater(std::move(qNetworkUpdater)),
37
+ >::DDPG(TrainingConfig& configIn,
38
+ QNetworkType& learningQNetworkIn,
39
+ PolicyNetworkType& policyNetworkIn,
40
+ NoiseType& noiseIn,
41
+ ReplayType& replayMethodIn,
42
+ UpdaterType qNetworkUpdaterIn,
43
+ UpdaterType policyNetworkUpdaterIn,
44
+ EnvironmentType environmentIn):
45
+ config(configIn),
46
+ learningQNetwork(learningQNetworkIn),
47
+ policyNetwork(policyNetworkIn),
48
+ noise(noiseIn),
49
+ replayMethod(replayMethodIn),
50
+ qNetworkUpdater(std::move(qNetworkUpdaterIn)),
51
51
  #if ENS_VERSION_MAJOR >= 2
52
52
  qNetworkUpdatePolicy(NULL),
53
53
  #endif
54
- policyNetworkUpdater(std::move(policyNetworkUpdater)),
54
+ policyNetworkUpdater(std::move(policyNetworkUpdaterIn)),
55
55
  #if ENS_VERSION_MAJOR >= 2
56
56
  policyNetworkUpdatePolicy(NULL),
57
57
  #endif
58
- environment(std::move(environment)),
58
+ environment(std::move(environmentIn)),
59
59
  totalSteps(0),
60
60
  deterministic(false)
61
61
  {
@@ -121,7 +121,7 @@ class Acrobot
121
121
  Acrobot(const size_t maxSteps = 500,
122
122
  const double gravity = 9.81,
123
123
  const double linkLength1 = 1.0,
124
- const double linkLength2 = 1.0,
124
+ const double /* linkLength2 */ = 1.0,
125
125
  const double linkMass1 = 1.0,
126
126
  const double linkMass2 = 1.0,
127
127
  const double linkCom1 = 0.5,
@@ -134,7 +134,7 @@ class Acrobot
134
134
  maxSteps(maxSteps),
135
135
  gravity(gravity),
136
136
  linkLength1(linkLength1),
137
- linkLength2(linkLength2),
137
+ //linkLength2(linkLength2),
138
138
  linkMass1(linkMass1),
139
139
  linkMass2(linkMass2),
140
140
  linkCom1(linkCom1),
@@ -360,8 +360,8 @@ class Acrobot
360
360
  //! Locally-stored length of link 1.
361
361
  double linkLength1;
362
362
 
363
- //! Locally-stored length of link 2.
364
- double linkLength2;
363
+ //! Locally-stored length of link 2. (NOTE: not currently used)
364
+ //double linkLength2;
365
365
 
366
366
  //! Locally-stored mass of link 1.
367
367
  double linkMass1;
@@ -125,7 +125,7 @@ class CartPole
125
125
  const double doneReward = 1.0) :
126
126
  maxSteps(maxSteps),
127
127
  gravity(gravity),
128
- massCart(massCart),
128
+ //massCart(massCart),
129
129
  massPole(massPole),
130
130
  totalMass(massCart + massPole),
131
131
  length(length),
@@ -247,8 +247,8 @@ class CartPole
247
247
  //! Locally-stored gravity.
248
248
  double gravity;
249
249
 
250
- //! Locally-stored mass of the cart.
251
- double massCart;
250
+ //! Locally-stored mass of the cart. NOTE: not currently used.
251
+ //double massCart;
252
252
 
253
253
  //! Locally-stored mass of the pole.
254
254
  double massPole;
@@ -104,7 +104,8 @@ class ContinuousDoublePoleCart
104
104
  * @param l2 The length of the second pole.
105
105
  * @param gravity The gravity constant.
106
106
  * @param massCart The mass of the cart.
107
- * @param forceMag The magnitude of the applied force.
107
+ * @param forceMag The magnitude of the applied force. NOTE: not currently
108
+ * used.
108
109
  * @param tau The time interval.
109
110
  * @param thetaThresholdRadians The maximum angle.
110
111
  * @param xThreshold The maximum position.
@@ -118,7 +119,7 @@ class ContinuousDoublePoleCart
118
119
  const double l2 = 0.05,
119
120
  const double gravity = 9.8,
120
121
  const double massCart = 1.0,
121
- const double forceMag = 10.0,
122
+ const double /* forceMag */ = 10.0,
122
123
  const double tau = 0.02,
123
124
  const double thetaThresholdRadians = 36 * 2 *
124
125
  3.1416 / 360,
@@ -131,7 +132,7 @@ class ContinuousDoublePoleCart
131
132
  l2(l2),
132
133
  gravity(gravity),
133
134
  massCart(massCart),
134
- forceMag(forceMag),
135
+ //forceMag(forceMag),
135
136
  tau(tau),
136
137
  thetaThresholdRadians(thetaThresholdRadians),
137
138
  xThreshold(xThreshold),
@@ -340,8 +341,8 @@ class ContinuousDoublePoleCart
340
341
  //! Locally-stored mass of the cart.
341
342
  double massCart;
342
343
 
343
- //! Locally-stored magnitude of the applied force.
344
- double forceMag;
344
+ //! Locally-stored magnitude of the applied force. NOTE: not currently used.
345
+ //double forceMag;
345
346
 
346
347
  //! Locally-stored time interval.
347
348
  double tau;
@@ -111,18 +111,19 @@ class Pendulum
111
111
  * @param maxAngularVelocity Maximum angular velocity.
112
112
  * @param maxTorque Maximum torque.
113
113
  * @param dt The differential value.
114
- * @param doneReward The reward recieved by the agent on success.
114
+ * @param doneReward The reward recieved by the agent on success. NOTE: not
115
+ * currently used.
115
116
  */
116
117
  Pendulum(const size_t maxSteps = 200,
117
118
  const double maxAngularVelocity = 8,
118
119
  const double maxTorque = 2.0,
119
120
  const double dt = 0.05,
120
- const double doneReward = 0.0) :
121
+ const double /* doneReward */ = 0.0) :
121
122
  maxSteps(maxSteps),
122
123
  maxAngularVelocity(maxAngularVelocity),
123
124
  maxTorque(maxTorque),
124
125
  dt(dt),
125
- doneReward(doneReward),
126
+ //doneReward(doneReward),
126
127
  stepsPerformed(0)
127
128
  { /* Nothing to do here */ }
128
129
 
@@ -254,8 +255,8 @@ class Pendulum
254
255
  //! Locally-stored dt.
255
256
  double dt;
256
257
 
257
- //! Locally-stored done reward.
258
- double doneReward;
258
+ //! Locally-stored done reward. NOTE: not currently used.
259
+ //double doneReward;
259
260
 
260
261
  //! Locally-stored number of steps performed.
261
262
  size_t stepsPerformed;
@@ -35,9 +35,9 @@ class AggregatedPolicy
35
35
  * User should make sure its size is same as the number of policies
36
36
  * and the sum of its element is equal to 1.
37
37
  */
38
- AggregatedPolicy(std::vector<PolicyType> policies,
38
+ AggregatedPolicy(std::vector<PolicyType> policiesIn,
39
39
  const arma::colvec& distribution) :
40
- policies(std::move(policies)),
40
+ policies(std::move(policiesIn)),
41
41
  sampler({distribution})
42
42
  { /* Nothing to do here. */ };
43
43
 
@@ -29,21 +29,21 @@ QLearning<
29
29
  UpdaterType,
30
30
  PolicyType,
31
31
  ReplayType
32
- >::QLearning(TrainingConfig& config,
32
+ >::QLearning(TrainingConfig& configIn,
33
33
  NetworkType& network,
34
- PolicyType& policy,
35
- ReplayType& replayMethod,
36
- UpdaterType updater,
37
- EnvironmentType environment):
38
- config(config),
34
+ PolicyType& policyIn,
35
+ ReplayType& replayMethodIn,
36
+ UpdaterType updaterIn,
37
+ EnvironmentType environmentIn):
38
+ config(configIn),
39
39
  learningNetwork(network),
40
- policy(policy),
41
- replayMethod(replayMethod),
42
- updater(std::move(updater)),
40
+ policy(policyIn),
41
+ replayMethod(replayMethodIn),
42
+ updater(std::move(updaterIn)),
43
43
  #if ENS_VERSION_MAJOR >= 2
44
44
  updatePolicy(NULL),
45
45
  #endif
46
- environment(std::move(environment)),
46
+ environment(std::move(environmentIn)),
47
47
  totalSteps(0),
48
48
  deterministic(false)
49
49
  {
@@ -78,21 +78,23 @@ class CategoricalDQN
78
78
  vMax(config.VMax()),
79
79
  isNoisy(isNoisy)
80
80
  {
81
- network.Add(new Linear(h1));
82
- network.Add(new ReLU());
81
+ network.template Add<Linear>(h1);
82
+ network.template Add<ReLU>();
83
83
  if (isNoisy)
84
84
  {
85
- noisyLayerIndex.push_back(network.Network().size());
86
- network.Add(new NoisyLinear(h2));
87
- network.Add(new ReLU());
88
- noisyLayerIndex.push_back(network.Network().size());
89
- network.Add(new NoisyLinear(outputDim * atomSize));
85
+ network.template Add<NoisyLinear>(h2);
86
+ noisyLayers.push_back(
87
+ dynamic_cast<NoisyLinear<>*>(network.Network().back()));
88
+ network.template Add<ReLU>();
89
+ network.template Add<NoisyLinear>(outputDim * atomSize);
90
+ noisyLayers.push_back(
91
+ dynamic_cast<NoisyLinear<>*>(network.Network().back()));
90
92
  }
91
93
  else
92
94
  {
93
- network.Add(new Linear(h2));
94
- network.Add(new ReLU());
95
- network.Add(new Linear(outputDim * atomSize));
95
+ network.template Add<Linear>(h2);
96
+ network.template Add<ReLU>();
97
+ network.template Add<Linear>(outputDim * atomSize);
96
98
  }
97
99
  }
98
100
 
@@ -104,16 +106,19 @@ class CategoricalDQN
104
106
  * @param config Hyper-parameters for categorical dqn.
105
107
  * @param isNoisy Specifies whether the network needs to be of type noisy.
106
108
  */
107
- CategoricalDQN(NetworkType& network,
109
+ CategoricalDQN(NetworkType& networkIn,
108
110
  TrainingConfig config,
109
111
  const bool isNoisy = false):
110
- network(std::move(network)),
112
+ network(std::move(networkIn)),
111
113
  atomSize(config.AtomSize()),
112
114
  vMin(config.VMin()),
113
115
  vMax(config.VMax()),
114
116
  isNoisy(isNoisy)
115
117
  { /* Nothing to do here. */ }
116
118
 
119
+ // TODO: implement copy constructor and operator
120
+ CategoricalDQN(const CategoricalDQN& other) = delete;
121
+
117
122
  /**
118
123
  * Predict the responses to a given set of predictors. The responses will
119
124
  * reflect the output of the given output layer as returned by the
@@ -176,10 +181,9 @@ class CategoricalDQN
176
181
  */
177
182
  void ResetNoise()
178
183
  {
179
- for (size_t i = 0; i < noisyLayerIndex.size(); ++i)
184
+ for (size_t i = 0; i < noisyLayers.size(); ++i)
180
185
  {
181
- dynamic_cast<NoisyLinear*>(
182
- (network.Network()[noisyLayerIndex[i]]))->ResetNoise();
186
+ noisyLayers[i]->ResetNoise();
183
187
  }
184
188
  }
185
189
 
@@ -228,10 +232,10 @@ class CategoricalDQN
228
232
  bool isNoisy;
229
233
 
230
234
  //! Locally-stored indexes of noisy layers in the network.
231
- std::vector<size_t> noisyLayerIndex;
235
+ std::vector<NoisyLinear<>*> noisyLayers;
232
236
 
233
237
  //! Locally-stored softmax activation function.
234
- Softmax softMax;
238
+ Softmax<> softMax;
235
239
 
236
240
  //! Locally-stored activations from softMax.
237
241
  arma::mat activations;
@@ -56,17 +56,15 @@ class DuelingDQN
56
56
  //! Default constructor.
57
57
  DuelingDQN() : isNoisy(false)
58
58
  {
59
- // TODO: this really ought to use a DAG network, but that's not implemented
60
- // yet.
61
- featureNetwork = new MultiLayer<arma::mat>();
62
- valueNetwork = new MultiLayer<arma::mat>();
63
- advantageNetwork = new MultiLayer<arma::mat>();
64
- concat = new Concat();
59
+ MultiLayer<arma::mat> featureNetwork;
60
+ MultiLayer<arma::mat> valueNetwork;
61
+ MultiLayer<arma::mat> advantageNetwork;
62
+ Concat concat;
65
63
 
66
- concat->Add(valueNetwork);
67
- concat->Add(advantageNetwork);
64
+ concat.Add(std::move(valueNetwork));
65
+ concat.Add(std::move(advantageNetwork));
68
66
  completeNetwork.Add(featureNetwork);
69
- completeNetwork.Add(concat);
67
+ completeNetwork.Add(std::move(concat));
70
68
  }
71
69
 
72
70
  /**
@@ -88,43 +86,52 @@ class DuelingDQN
88
86
  completeNetwork(outputLayer, init),
89
87
  isNoisy(isNoisy)
90
88
  {
91
- featureNetwork = new MultiLayer<arma::mat>();
92
- featureNetwork->Add(new Linear(h1));
93
- featureNetwork->Add(new ReLU());
89
+ // TODO: this really ought to use a DAG network, but that's not implemented
90
+ // yet.
91
+ MultiLayer<arma::mat> featureNetwork;
92
+ featureNetwork.template Add<Linear>(h1);
93
+ featureNetwork.template Add<ReLU<>>();
94
94
 
95
- valueNetwork = new MultiLayer<arma::mat>();
96
- advantageNetwork = new MultiLayer<arma::mat>();
95
+ MultiLayer<arma::mat> valueNetwork;
96
+ MultiLayer<arma::mat> advantageNetwork;
97
97
 
98
98
  if (isNoisy)
99
99
  {
100
- noisyLayerIndex.push_back(valueNetwork->Network().size());
101
- valueNetwork->Add(new NoisyLinear(h2));
102
- advantageNetwork->Add(new NoisyLinear(h2));
103
-
104
- valueNetwork->Add(new ReLU());
105
- advantageNetwork->Add(new ReLU());
106
-
107
- noisyLayerIndex.push_back(valueNetwork->Network().size());
108
- valueNetwork->Add(new NoisyLinear(1));
109
- advantageNetwork->Add(new NoisyLinear(outputDim));
100
+ valueNetwork.Add<NoisyLinear>(h2);
101
+ noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
102
+ valueNetwork.Network().back()));
103
+
104
+ advantageNetwork.Add<NoisyLinear>(h2);
105
+ noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
106
+ advantageNetwork.Network().back()));
107
+
108
+ valueNetwork.template Add<ReLU>();
109
+ advantageNetwork.template Add<ReLU>();
110
+
111
+ valueNetwork.template Add<NoisyLinear>(1);
112
+ noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
113
+ valueNetwork.Network().back()));
114
+ advantageNetwork.template Add<NoisyLinear>(outputDim);
115
+ noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
116
+ advantageNetwork.Network().back()));
110
117
  }
111
118
  else
112
119
  {
113
- valueNetwork->Add(new Linear(h2));
114
- valueNetwork->Add(new ReLU());
115
- valueNetwork->Add(new Linear(1));
120
+ valueNetwork.template Add<Linear>(h2);
121
+ valueNetwork.template Add<ReLU>();
122
+ valueNetwork.template Add<Linear>(1);
116
123
 
117
- advantageNetwork->Add(new Linear(h2));
118
- advantageNetwork->Add(new ReLU());
119
- advantageNetwork->Add(new Linear(outputDim));
124
+ advantageNetwork.template Add<Linear>(h2);
125
+ advantageNetwork.template Add<ReLU>();
126
+ advantageNetwork.template Add<Linear>(outputDim);
120
127
  }
121
128
 
122
- concat = new Concat();
123
- concat->Add(valueNetwork);
124
- concat->Add(advantageNetwork);
129
+ Concat concat;
130
+ concat.Add(std::move(valueNetwork));
131
+ concat.Add(std::move(advantageNetwork));
125
132
 
126
- completeNetwork.Add(featureNetwork);
127
- completeNetwork.Add(concat);
133
+ completeNetwork.Add(std::move(featureNetwork));
134
+ completeNetwork.Add(std::move(concat));
128
135
  }
129
136
 
130
137
  /**
@@ -135,35 +142,35 @@ class DuelingDQN
135
142
  * @param valueNetwork The value network to be used by DuelingDQN class.
136
143
  * @param isNoisy Specifies whether the network needs to be of type noisy.
137
144
  */
138
- DuelingDQN(FeatureNetworkType& featureNetwork,
139
- AdvantageNetworkType& advantageNetwork,
140
- ValueNetworkType& valueNetwork,
145
+ DuelingDQN(FeatureNetworkType&& featureNetwork,
146
+ AdvantageNetworkType&& advantageNetwork,
147
+ ValueNetworkType&& valueNetwork,
141
148
  const bool isNoisy = false):
142
- featureNetwork(featureNetwork),
143
- advantageNetwork(advantageNetwork),
144
- valueNetwork(valueNetwork),
145
149
  isNoisy(isNoisy)
146
150
  {
147
- concat = new Concat();
148
- concat->Add(valueNetwork);
149
- concat->Add(advantageNetwork);
150
- completeNetwork.Add(featureNetwork);
151
- completeNetwork.Add(concat);
151
+ Concat concat;
152
+ concat.Add(std::move(valueNetwork));
153
+ concat.Add(std::move(advantageNetwork));
154
+ completeNetwork.Add(std::move(featureNetwork));
155
+ completeNetwork.Add(std::move(concat));
152
156
  }
153
157
 
154
- //! Copy constructor.
155
- DuelingDQN(const DuelingDQN& /* model */) : isNoisy(false)
156
- { /* Nothing to do here. */ }
158
+ // Copy constructor.
159
+ //DuelingDQN(const DuelingDQN& model) : isNoisy(false)
160
+ // {
161
+ // // Use copy operator.
162
+ // *this = model;
163
+ // }
157
164
 
158
- //! Copy assignment operator.
159
- void operator = (const DuelingDQN& model)
160
- {
161
- *valueNetwork = *model.valueNetwork;
162
- *advantageNetwork = *model.advantageNetwork;
163
- *featureNetwork = *model.featureNetwork;
164
- isNoisy = model.isNoisy;
165
- noisyLayerIndex = model.noisyLayerIndex;
166
- }
165
+ // Copy assignment operator.
166
+ // void operator=(const DuelingDQN& model)
167
+ // {
168
+ // completeNetwork = model.completeNetwork;
169
+
170
+ // isNoisy = model.isNoisy;
171
+ // }
172
+
173
+ DuelingDQN(const DuelingDQN& model) = delete;
167
174
 
168
175
  /**
169
176
  * Predict the responses to a given set of predictors. The responses will
@@ -234,12 +241,9 @@ class DuelingDQN
234
241
  */
235
242
  void ResetNoise()
236
243
  {
237
- for (size_t i = 0; i < noisyLayerIndex.size(); i++)
244
+ for (size_t i = 0; i < noisyLayers.size(); i++)
238
245
  {
239
- dynamic_cast<NoisyLinear*>(
240
- (valueNetwork->Network()[noisyLayerIndex[i]]))->ResetNoise();
241
- dynamic_cast<NoisyLinear*>(
242
- (advantageNetwork->Network()[noisyLayerIndex[i]]))->ResetNoise();
246
+ noisyLayers[i]->ResetNoise();
243
247
  }
244
248
  }
245
249
 
@@ -252,24 +256,12 @@ class DuelingDQN
252
256
  //! Locally-stored complete network.
253
257
  CompleteNetworkType completeNetwork;
254
258
 
255
- //! Locally-stored concat network.
256
- Concat* concat;
257
-
258
- //! Locally-stored feature network.
259
- FeatureNetworkType* featureNetwork;
260
-
261
- //! Locally-stored advantage network.
262
- AdvantageNetworkType* advantageNetwork;
263
-
264
- //! Locally-stored value network.
265
- ValueNetworkType* valueNetwork;
259
+ // Pointers to noisy layers.
260
+ std::vector<NoisyLinear<>*> noisyLayers;
266
261
 
267
262
  //! Locally-stored check for noisy network.
268
263
  bool isNoisy;
269
264
 
270
- //! Locally-stored indexes of noisy layers in the network.
271
- std::vector<size_t> noisyLayerIndex;
272
-
273
265
  //! Locally-stored actionValues of the network.
274
266
  arma::mat actionValues;
275
267
 
@@ -58,21 +58,21 @@ class SimpleDQN
58
58
  network(outputLayer, init),
59
59
  isNoisy(isNoisy)
60
60
  {
61
- network.Add(new Linear(h1));
62
- network.Add(new ReLU());
61
+ network.template Add<Linear>(h1);
62
+ network.template Add<ReLU>();
63
63
  if (isNoisy)
64
64
  {
65
65
  noisyLayerIndex.push_back(network.Network().size());
66
- network.Add(new NoisyLinear(h2));
67
- network.Add(new ReLU());
66
+ network.template Add<NoisyLinear>(h2);
67
+ network.template Add<ReLU>();
68
68
  noisyLayerIndex.push_back(network.Network().size());
69
- network.Add(new NoisyLinear(outputDim));
69
+ network.template Add<NoisyLinear>(outputDim);
70
70
  }
71
71
  else
72
72
  {
73
- network.Add(new Linear(h2));
74
- network.Add(new ReLU());
75
- network.Add(new Linear(outputDim));
73
+ network.template Add<Linear>(h2);
74
+ network.template Add<ReLU>();
75
+ network.template Add<Linear>(outputDim);
76
76
  }
77
77
  }
78
78
 
@@ -129,7 +129,7 @@ class SimpleDQN
129
129
  {
130
130
  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
131
131
  {
132
- dynamic_cast<NoisyLinear*>(
132
+ dynamic_cast<NoisyLinear<>*>(
133
133
  network.Network()[noisyLayerIndex[i]])->ResetNoise();
134
134
  }
135
135
  }
@@ -32,26 +32,26 @@ SAC<
32
32
  PolicyNetworkType,
33
33
  UpdaterType,
34
34
  ReplayType
35
- >::SAC(TrainingConfig& config,
36
- QNetworkType& learningQ1Network,
37
- PolicyNetworkType& policyNetwork,
38
- ReplayType& replayMethod,
39
- UpdaterType qNetworkUpdater,
40
- UpdaterType policyNetworkUpdater,
41
- EnvironmentType environment):
42
- config(config),
43
- learningQ1Network(learningQ1Network),
44
- policyNetwork(policyNetwork),
45
- replayMethod(replayMethod),
46
- qNetworkUpdater(std::move(qNetworkUpdater)),
35
+ >::SAC(TrainingConfig& configIn,
36
+ QNetworkType& learningQ1NetworkIn,
37
+ PolicyNetworkType& policyNetworkIn,
38
+ ReplayType& replayMethodIn,
39
+ UpdaterType qNetworkUpdaterIn,
40
+ UpdaterType policyNetworkUpdaterIn,
41
+ EnvironmentType environmentIn):
42
+ config(configIn),
43
+ learningQ1Network(learningQ1NetworkIn),
44
+ policyNetwork(policyNetworkIn),
45
+ replayMethod(replayMethodIn),
46
+ qNetworkUpdater(std::move(qNetworkUpdaterIn)),
47
47
  #if ENS_VERSION_MAJOR >= 2
48
48
  qNetworkUpdatePolicy(NULL),
49
49
  #endif
50
- policyNetworkUpdater(std::move(policyNetworkUpdater)),
50
+ policyNetworkUpdater(std::move(policyNetworkUpdaterIn)),
51
51
  #if ENS_VERSION_MAJOR >= 2
52
52
  policyNetworkUpdatePolicy(NULL),
53
53
  #endif
54
- environment(std::move(environment)),
54
+ environment(std::move(environmentIn)),
55
55
  totalSteps(0),
56
56
  deterministic(false)
57
57
  {
@@ -32,26 +32,26 @@ TD3<
32
32
  PolicyNetworkType,
33
33
  UpdaterType,
34
34
  ReplayType
35
- >::TD3(TrainingConfig& config,
36
- QNetworkType& learningQ1Network,
37
- PolicyNetworkType& policyNetwork,
38
- ReplayType& replayMethod,
39
- UpdaterType qNetworkUpdater,
40
- UpdaterType policyNetworkUpdater,
41
- EnvironmentType environment):
42
- config(config),
43
- learningQ1Network(learningQ1Network),
44
- policyNetwork(policyNetwork),
45
- replayMethod(replayMethod),
46
- qNetworkUpdater(std::move(qNetworkUpdater)),
35
+ >::TD3(TrainingConfig& configIn,
36
+ QNetworkType& learningQ1NetworkIn,
37
+ PolicyNetworkType& policyNetworkIn,
38
+ ReplayType& replayMethodIn,
39
+ UpdaterType qNetworkUpdaterIn,
40
+ UpdaterType policyNetworkUpdaterIn,
41
+ EnvironmentType environmentIn):
42
+ config(configIn),
43
+ learningQ1Network(learningQ1NetworkIn),
44
+ policyNetwork(policyNetworkIn),
45
+ replayMethod(replayMethodIn),
46
+ qNetworkUpdater(std::move(qNetworkUpdaterIn)),
47
47
  #if ENS_VERSION_MAJOR >= 2
48
48
  qNetworkUpdatePolicy(NULL),
49
49
  #endif
50
- policyNetworkUpdater(std::move(policyNetworkUpdater)),
50
+ policyNetworkUpdater(std::move(policyNetworkUpdaterIn)),
51
51
  #if ENS_VERSION_MAJOR >= 2
52
52
  policyNetworkUpdatePolicy(NULL),
53
53
  #endif
54
- environment(std::move(environment)),
54
+ environment(std::move(environmentIn)),
55
55
  totalSteps(0),
56
56
  deterministic(false)
57
57
  {
@@ -335,7 +335,7 @@ inline void SoftmaxRegressionFunction<MatType>::PartialGradient(
335
335
  const size_t j,
336
336
  GradType& gradient) const
337
337
  {
338
- gradient.zeros(arma::size(parameters));
338
+ gradient.zeros(size(parameters));
339
339
 
340
340
  DenseMatType probabilities;
341
341
  GetProbabilitiesMatrix(parameters, probabilities, 0, data.n_cols);
@@ -451,7 +451,7 @@ inline double ParallelSGD<ExponentialBackoff>::Optimize(
451
451
  // Get the stepsize for this iteration
452
452
  double stepSize = decayPolicy.StepSize(i);
453
453
 
454
- if (shuffle) // Determine order of visitation.
454
+ if (Shuffle()) // Determine order of visitation.
455
455
  std::shuffle(visitationOrder.begin(), visitationOrder.end(),
456
456
  mlpack::RandGen());
457
457
 
@@ -31,6 +31,7 @@ namespace adaboost { using namespace mlpack; }
31
31
  namespace amf { using namespace mlpack; }
32
32
  namespace ann { using namespace mlpack; }
33
33
  namespace cf { using namespace mlpack; }
34
+ namespace data { using namespace mlpack; }
34
35
  namespace dbscan { using namespace mlpack; }
35
36
  namespace det { using namespace mlpack; }
36
37
  namespace emst { using namespace mlpack; }
@@ -32,6 +32,7 @@
32
32
  #include <mlpack/core/cereal/pointer_vector_wrapper.hpp>
33
33
  #include <mlpack/core/cereal/pointer_wrapper.hpp>
34
34
  #include <mlpack/core/cereal/template_class_version.hpp>
35
+ #include <mlpack/core/cereal/low_precision.hpp>
35
36
  #include <mlpack/core/data/has_serialize.hpp>
36
37
 
37
38
  // Include ready to use utility function to check sizes of datasets.