mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -313,7 +313,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
313
313
  // Now, do the training.
314
314
  if (params.Has("training"))
315
315
  {
316
- data::NormalizeLabels(rawLabels, labels, model->mappings);
316
+ NormalizeLabels(rawLabels, labels, model->mappings);
317
317
  numClasses = params.Get<int>("num_classes") == 0 ?
318
318
  model->mappings.n_elem : params.Get<int>("num_classes");
319
319
  model->svm.Lambda() = lambda;
@@ -410,7 +410,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
410
410
  }
411
411
 
412
412
  model->svm.Classify(testSet, predictedLabels);
413
- data::RevertLabels(predictedLabels, model->mappings, predictions);
413
+ RevertLabels(predictedLabels, model->mappings, predictions);
414
414
 
415
415
  // Calculate accuracy, if desired.
416
416
  if (params.Has("test_labels"))
@@ -419,7 +419,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
419
419
  arma::Row<size_t> testRawLabels =
420
420
  std::move(params.Get<arma::Row<size_t>>("test_labels"));
421
421
 
422
- data::NormalizeLabels(testRawLabels, testLabels, model->mappings);
422
+ NormalizeLabels(testRawLabels, testLabels, model->mappings);
423
423
 
424
424
  if (testSet.n_cols != testLabels.n_elem)
425
425
  {
@@ -326,7 +326,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
326
326
  // Now, normalize the labels.
327
327
  arma::Col<size_t> mappings;
328
328
  arma::Row<size_t> labels;
329
- data::NormalizeLabels(rawLabels, labels, mappings);
329
+ NormalizeLabels(rawLabels, labels, mappings);
330
330
 
331
331
  arma::mat distance;
332
332
 
@@ -183,7 +183,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
183
183
  if (params.Has("reference"))
184
184
  {
185
185
  // Workaround: this avoids printing load information twice for the CLI
186
- // bindings, where GetPrintable() will trigger a call to data::Load(),
186
+ // bindings, where GetPrintable() will trigger a call to Load(),
187
187
  // which prints loading information in the middle of the Log::Info
188
188
  // message.
189
189
  (void) params.Get<arma::mat>("reference");
@@ -81,7 +81,7 @@ inline void MatrixCompletion::CheckValues()
81
81
  if (indices(0, i) >= m || indices(1, i) >= n)
82
82
  Log::Fatal << "MatrixCompletion::CheckValues(): indices ("
83
83
  << indices(0, i) << ", " << indices(1, i)
84
- << ") are out of bounds for matrix of size " << m << " x n!"
84
+ << ") are out of bounds for matrix of size " << m << " x " << n << "!"
85
85
  << std::endl;
86
86
  }
87
87
  }
@@ -364,7 +364,7 @@ void NaiveBayesClassifier<ModelMatType>::Classify(
364
364
  ModelMatType logLikelihoods;
365
365
  LogLikelihood(data, logLikelihoods);
366
366
 
367
- predictionProbs.set_size(arma::size(logLikelihoods));
367
+ predictionProbs.set_size(size(logLikelihoods));
368
368
  double maxValue, logProbX;
369
369
  for (size_t j = 0; j < data.n_cols; ++j)
370
370
  {
@@ -152,14 +152,14 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
152
152
  {
153
153
  // Load labels.
154
154
  Row<size_t> rawLabels = std::move(params.Get<Row<size_t>>("labels"));
155
- data::NormalizeLabels(rawLabels, labels, model->mappings);
155
+ NormalizeLabels(rawLabels, labels, model->mappings);
156
156
  }
157
157
  else
158
158
  {
159
159
  // Use the last row of the training data as the labels.
160
160
  Log::Info << "Using last dimension of training data as training labels."
161
161
  << endl;
162
- data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
162
+ NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
163
163
  model->mappings);
164
164
  // Remove the label row.
165
165
  trainingData.shed_row(trainingData.n_rows - 1);
@@ -200,7 +200,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
200
200
  {
201
201
  // Un-normalize labels to prepare output.
202
202
  Row<size_t> rawResults;
203
- data::RevertLabels(predictions, model->mappings, rawResults);
203
+ RevertLabels(predictions, model->mappings, rawResults);
204
204
 
205
205
  if (params.Has("predictions"))
206
206
  params.Get<Row<size_t>>("predictions") = std::move(rawResults);
@@ -217,7 +217,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
217
217
  // Now, normalize the labels.
218
218
  arma::Col<size_t> mappings;
219
219
  arma::Row<size_t> labels;
220
- data::NormalizeLabels(rawLabels, labels, mappings);
220
+ NormalizeLabels(rawLabels, labels, mappings);
221
221
 
222
222
  arma::mat distance;
223
223
 
@@ -179,16 +179,16 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
179
179
  const string algorithm = params.Get<string>("algorithm");
180
180
  RequireParamInSet<string>(params, "algorithm", { "naive", "single_tree",
181
181
  "dual_tree", "greedy" }, true, "unknown neighbor search algorithm");
182
- NeighborSearchMode searchMode = DUAL_TREE_MODE;
182
+ NeighborSearchStrategy searchStrategy = DUAL_TREE;
183
183
 
184
184
  if (algorithm == "naive")
185
- searchMode = NAIVE_MODE;
185
+ searchStrategy = NAIVE;
186
186
  else if (algorithm == "single_tree")
187
- searchMode = SINGLE_TREE_MODE;
187
+ searchStrategy = SINGLE_TREE;
188
188
  else if (algorithm == "dual_tree")
189
- searchMode = DUAL_TREE_MODE;
189
+ searchStrategy = DUAL_TREE;
190
190
  else if (algorithm == "greedy")
191
- searchMode = GREEDY_SINGLE_TREE_MODE;
191
+ searchStrategy = GREEDY_SINGLE_TREE;
192
192
 
193
193
  if (params.Has("reference"))
194
194
  {
@@ -240,7 +240,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
240
240
  Log::Info << "Using reference data from "
241
241
  << params.GetPrintable<arma::mat>("reference") << "." << endl;
242
242
 
243
- kfn->BuildModel(timers, std::move(referenceSet), searchMode, epsilon);
243
+ kfn->BuildModel(timers, std::move(referenceSet), searchStrategy, epsilon);
244
244
  }
245
245
  else
246
246
  {
@@ -248,7 +248,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
248
248
  kfn = params.Get<KFNModel*>("input_model");
249
249
 
250
250
  // Adjust search strategy.
251
- kfn->SearchMode() = searchMode;
251
+ kfn->SearchStrategy() = searchStrategy;
252
252
  kfn->Epsilon() = epsilon;
253
253
 
254
254
  // If leaf_size wasn't provided, let's consider the current value in the
@@ -272,7 +272,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
272
272
  if (params.Has("query"))
273
273
  {
274
274
  // Workaround: this avoids printing load information twice for the CLI
275
- // bindings, where GetPrintable() will trigger a call to data::Load(),
275
+ // bindings, where GetPrintable() will trigger a call to Load(),
276
276
  // which prints loading information in the middle of the Log::Info
277
277
  // message.
278
278
  (void) params.Get<arma::mat>("query");
@@ -187,16 +187,16 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
187
187
  const string algorithm = params.Get<string>("algorithm");
188
188
  RequireParamInSet<string>(params, "algorithm", { "naive", "single_tree",
189
189
  "dual_tree", "greedy" }, true, "unknown neighbor search algorithm");
190
- NeighborSearchMode searchMode = DUAL_TREE_MODE;
190
+ NeighborSearchStrategy searchStrategy = DUAL_TREE;
191
191
 
192
192
  if (algorithm == "naive")
193
- searchMode = NAIVE_MODE;
193
+ searchStrategy = NAIVE;
194
194
  else if (algorithm == "single_tree")
195
- searchMode = SINGLE_TREE_MODE;
195
+ searchStrategy = SINGLE_TREE;
196
196
  else if (algorithm == "dual_tree")
197
- searchMode = DUAL_TREE_MODE;
197
+ searchStrategy = DUAL_TREE;
198
198
  else if (algorithm == "greedy")
199
- searchMode = GREEDY_SINGLE_TREE_MODE;
199
+ searchStrategy = GREEDY_SINGLE_TREE;
200
200
 
201
201
  if (params.Has("reference"))
202
202
  {
@@ -253,7 +253,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
253
253
  Log::Info << "Using reference data from "
254
254
  << params.GetPrintable<arma::mat>("reference") << "." << endl;
255
255
 
256
- knn->BuildModel(timers, std::move(referenceSet), searchMode, epsilon);
256
+ knn->BuildModel(timers, std::move(referenceSet), searchStrategy, epsilon);
257
257
  }
258
258
  else
259
259
  {
@@ -261,7 +261,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
261
261
  knn = params.Get<KNNModel*>("input_model");
262
262
 
263
263
  // Adjust search strategy.
264
- knn->SearchMode() = searchMode;
264
+ knn->SearchStrategy() = searchStrategy;
265
265
  knn->Epsilon() = epsilon;
266
266
 
267
267
  // If leaf_size wasn't provided, let's consider the current value in the
@@ -285,7 +285,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
285
285
  if (params.Has("query"))
286
286
  {
287
287
  // Workaround: this avoids printing load information twice for the CLI
288
- // bindings, where GetPrintable() will trigger a call to data::Load(),
288
+ // bindings, where GetPrintable() will trigger a call to Load(),
289
289
  // which prints loading information in the middle of the Log::Info
290
290
  // message.
291
291
  (void) params.Get<arma::mat>("query");
@@ -16,8 +16,6 @@
16
16
  #include <mlpack/core.hpp>
17
17
 
18
18
  #include "neighbor_search_stat.hpp"
19
- #include "sort_policies/nearest_neighbor_sort.hpp"
20
- #include "sort_policies/furthest_neighbor_sort.hpp"
21
19
  #include "neighbor_search_rules.hpp"
22
20
  #include "unmap.hpp"
23
21
 
@@ -32,7 +30,17 @@ template<typename SortPolicy,
32
30
  template<typename RuleType> class SingleTreeTraversalType>
33
31
  class LeafSizeNSWrapper;
34
32
 
35
- //! NeighborSearchMode represents the different neighbor search modes available.
33
+ // NeighborSearchStrategy represents the different neighbor search strategies
34
+ // available.
35
+ enum NeighborSearchStrategy
36
+ {
37
+ NAIVE,
38
+ SINGLE_TREE,
39
+ DUAL_TREE,
40
+ GREEDY_SINGLE_TREE
41
+ };
42
+
43
+ // This is for backward compatibility and will be removed in mlpack 5.0.0.
36
44
  enum NeighborSearchMode
37
45
  {
38
46
  NAIVE_MODE,
@@ -41,6 +49,36 @@ enum NeighborSearchMode
41
49
  GREEDY_SINGLE_TREE_MODE
42
50
  };
43
51
 
52
+ // This is for backward compatibility and will be removed in mlpack 5.0.0.
53
+ inline NeighborSearchStrategy ModeToStrategy(const NeighborSearchMode& mode)
54
+ {
55
+ switch (mode)
56
+ {
57
+ case NAIVE_MODE: return NAIVE;
58
+ case SINGLE_TREE_MODE: return SINGLE_TREE;
59
+ case DUAL_TREE_MODE: return DUAL_TREE;
60
+ case GREEDY_SINGLE_TREE_MODE: return GREEDY_SINGLE_TREE;
61
+ }
62
+
63
+ // Not reached for valid enum values; avoids a missing-return warning.
64
+ return DUAL_TREE;
65
+ }
66
+
67
+ // This is for backward compatibility and will be removed in mlpack 5.0.0.
68
+ inline NeighborSearchMode StrategyToMode(const NeighborSearchStrategy& strategy)
69
+ {
70
+ switch (strategy)
71
+ {
72
+ case NAIVE: return NAIVE_MODE;
73
+ case SINGLE_TREE: return SINGLE_TREE_MODE;
74
+ case DUAL_TREE: return DUAL_TREE_MODE;
75
+ case GREEDY_SINGLE_TREE: return GREEDY_SINGLE_TREE_MODE;
76
+ }
77
+
78
+ // Not reached for valid enum values; avoids a missing-return warning.
79
+ return DUAL_TREE_MODE;
80
+ }
81
+
44
82
  /**
45
83
  * The NeighborSearch class is a template class for performing distance-based
46
84
  * neighbor searches. It takes a query dataset and a reference dataset (or just
@@ -97,26 +135,34 @@ class NeighborSearch
97
135
  * pre-constructing the trees, passing std::move(yourReferenceSet).
98
136
  *
99
137
  * @param referenceSet Set of reference points.
100
- * @param mode Neighbor search mode.
138
+ * @param strategy Neighbor search strategy.
101
139
  * @param epsilon Relative approximate error (non-negative).
102
140
  * @param distance An optional instance of the DistanceType class.
103
141
  */
104
142
  NeighborSearch(MatType referenceSet,
105
- const NeighborSearchMode mode = DUAL_TREE_MODE,
143
+ const NeighborSearchStrategy strategy = DUAL_TREE,
106
144
  const double epsilon = 0,
107
145
  const DistanceType distance = DistanceType());
108
146
 
147
+ [[deprecated("Will be removed in mlpack 5.0.0. Instead of a "
148
+ "NeighborSearchMode, pass a NeighborSearchStrategy.")]]
149
+ NeighborSearch(MatType referenceSet,
150
+ const NeighborSearchMode mode,
151
+ const double epsilon = 0,
152
+ const DistanceType distance = DistanceType()) :
153
+ NeighborSearch(std::move(referenceSet), ModeToStrategy(mode), epsilon,
154
+ distance) { }
155
+
109
156
  /**
110
157
  * Initialize the NeighborSearch object with a copy of the given
111
158
  * pre-constructed reference tree (this is the tree built on the points that
112
- * will be searched). Optionally, choose to use single-tree mode. Naive mode
113
- * is not available as an option for this constructor. Additionally, an
114
- * instantiated distance metric can be given, for cases where the distance
115
- * metric holds data.
159
+ * will be searched). Optionally, choose to use a different search strategy.
160
+ * Additionally, an instantiated distance metric can be given, for cases where
161
+ * the distance metric holds data.
116
162
  *
117
163
  * This method will copy the given tree. When copies must absolutely be
118
164
  * avoided, you can avoid this copy, while taking ownership of the given tree,
119
- * by passing std::move(yourReferenceTree)
165
+ * by passing std::move(yourReferenceTree).
120
166
  *
121
167
  * @note
122
168
  * Mapping the points of the matrix back to their original indices is not done
@@ -127,12 +173,28 @@ class NeighborSearch
127
173
  * @param referenceTree Pre-built tree for reference points.
128
174
  * @param strategy Neighbor search strategy.
129
175
  * @param epsilon Relative approximate error (non-negative).
130
- * @param distance Instantiated distance metric.
131
176
  */
132
177
  NeighborSearch(Tree referenceTree,
133
- const NeighborSearchMode mode = DUAL_TREE_MODE,
134
- const double epsilon = 0,
135
- const DistanceType distance = DistanceType());
178
+ const NeighborSearchStrategy strategy = DUAL_TREE,
179
+ const double epsilon = 0);
180
+
181
+ [[deprecated("Will be removed in mlpack 5.0.0. Instead of a "
182
+ "NeighborSearchMode, pass a NeighborSearchStrategy.")]]
183
+ NeighborSearch(Tree referenceTree,
184
+ const NeighborSearchMode mode,
185
+ const double epsilon = 0) :
186
+ NeighborSearch(std::move(referenceTree), ModeToStrategy(mode), epsilon) {}
187
+
188
+ // This overload is kept only for backward compatibility. Prefer the
189
+ // overloads above, which take no `distance` argument and instead use the
190
+ // distance stored in the given tree.
191
+ [[deprecated("Will be removed in mlpack 5.0.0. Use the version without "
192
+ "`distance` instead (`referenceTree.Distance()` will be used as "
193
+ "the distance metric).")]]
194
+ NeighborSearch(Tree referenceTree,
195
+ const NeighborSearchMode mode,
196
+ const double epsilon,
197
+ const DistanceType distance);
136
198
 
137
199
  /**
138
200
  * Create a NeighborSearch object without any reference data. If Search() is
@@ -143,10 +205,17 @@ class NeighborSearch
143
205
  * @param epsilon Relative approximate error (non-negative).
144
206
  * @param distance Instantiated distance metric.
145
207
  */
146
- NeighborSearch(const NeighborSearchMode mode = DUAL_TREE_MODE,
208
+ NeighborSearch(const NeighborSearchStrategy strategy = DUAL_TREE,
147
209
  const double epsilon = 0,
148
210
  const DistanceType distance = DistanceType());
149
211
 
212
+ [[deprecated("Will be removed in mlpack 5.0.0. Instead of a "
213
+ "NeighborSearchMode, pass a NeighborSearchStrategy.")]]
214
+ NeighborSearch(const NeighborSearchMode mode,
215
+ const double epsilon = 0,
216
+ const DistanceType distance = DistanceType()) :
217
+ NeighborSearch(ModeToStrategy(mode), epsilon, distance) { }
218
+
150
219
  /**
151
220
  * Construct the NeighborSearch object by copying the given NeighborSearch
152
221
  * object.
@@ -213,8 +282,8 @@ class NeighborSearch
213
282
  *
214
283
  * If querySet contains only a few query points, the extra cost of building a
215
284
  * tree on the points for dual-tree search may not be warranted, and it may be
216
- * worthwhile to set singleMode = false (either in the constructor or with
217
- * SingleMode()).
285
+ * worthwhile to set the strategy to SINGLE_TREE (either in the constructor
286
+ * or with SearchStrategy()).
218
287
  *
219
288
  * @param querySet Set of query points (can be just one point).
220
289
  * @param k Number of neighbors to search for.
@@ -311,17 +380,31 @@ class NeighborSearch
311
380
  static double Recall(arma::Mat<IndexType>& foundNeighbors,
312
381
  arma::Mat<IndexType>& realNeighbors);
313
382
 
314
- //! Return the total number of base case evaluations performed during the last
315
- //! search.
383
+ // Reset all bounding quantities in a prebuilt external tree.
384
+ // When calling Search() multiple times with a prebuilt query tree, this must
385
+ // be called between each Search() invocation!
386
+ static void ResetTree(Tree& tree);
387
+
388
+ // Return the total number of base case evaluations performed during the last
389
+ // search.
316
390
  size_t BaseCases() const { return baseCases; }
317
391
 
318
- //! Return the number of node combination scores during the last search.
392
+ // Return the number of node combination scores during the last search.
319
393
  size_t Scores() const { return scores; }
320
394
 
321
- //! Access the search mode.
395
+ // Access the search mode.
396
+ [[deprecated("Will be removed in mlpack 5.0.0. Use SearchStrategy() "
397
+ "instead.")]]
322
398
  NeighborSearchMode SearchMode() const { return searchMode; }
323
- //! Modify the search mode.
324
- NeighborSearchMode& SearchMode() { return searchMode; }
399
+ // Modify the search mode.
400
+ [[deprecated("Will be removed in mlpack 5.0.0. Use SearchStrategy() "
401
+ "instead.")]]
402
+ NeighborSearchMode& SearchMode() { searchModeMod = true; return searchMode; }
403
+
404
+ // Access the search strategy.
405
+ NeighborSearchStrategy SearchStrategy() const { return searchStrategy; }
406
+ // Modify the search strategy.
407
+ NeighborSearchStrategy& SearchStrategy() { return searchStrategy; }
325
408
 
326
409
  //! Access the relative error to be considered in approximate search.
327
410
  double Epsilon() const { return epsilon; }
@@ -341,37 +424,74 @@ class NeighborSearch
341
424
  void serialize(Archive& ar, const uint32_t version);
342
425
 
343
426
  private:
344
- //! Permutations of reference points during tree building.
427
+ // Permutations of reference points during tree building.
345
428
  std::vector<size_t> oldFromNewReferences;
346
- //! Pointer to the root of the reference tree.
429
+ // Pointer to the root of the reference tree.
347
430
  Tree* referenceTree;
348
- //! Reference dataset. In some situations we may be the owner of this.
431
+ // Reference dataset. In some situations we may be the owner of this.
349
432
  const MatType* referenceSet;
350
433
 
351
- //! Indicates the neighbor search mode.
434
+ // This is only kept for backward compatibility and will be removed in mlpack
435
+ // 5.0.0.
352
436
  NeighborSearchMode searchMode;
353
- //! Indicates the relative error to be considered in approximate search.
437
+ bool searchModeMod; // backward compat: set by non-const SearchMode()
438
+ // Indicates the neighbor search strategy.
439
+ NeighborSearchStrategy searchStrategy;
440
+ // Indicates the relative error to be considered in approximate search.
354
441
  double epsilon;
355
442
 
356
- //! Instantiation of distance metric.
443
+ // Instantiation of distance metric.
357
444
  DistanceType distance;
358
445
 
359
- //! The total number of base cases.
446
+ // The total number of base cases.
360
447
  size_t baseCases;
361
- //! The total number of scores (applicable for non-naive search).
448
+ // The total number of scores (applicable for non-naive search).
362
449
  size_t scores;
363
450
 
364
- //! If this is true, the reference tree bounds need to be reset on a call to
365
- //! Search() without a query set.
451
+ // If this is true, the reference tree bounds need to be reset on a call to
452
+ // Search() without a query set.
366
453
  bool treeNeedsReset;
367
454
 
368
- //! The NSModel class should have access to internal members.
455
+ // The NSModel class should have access to internal members.
369
456
  friend class LeafSizeNSWrapper<SortPolicy, TreeType, DualTreeTraversalType,
370
457
  SingleTreeTraversalType>;
371
458
  }; // class NeighborSearch
372
459
 
373
460
  } // namespace mlpack
374
461
 
462
+ // The CEREAL_TEMPLATE_CLASS_VERSION() macro does not work with template
463
+ // template parameters so we write it manually.
464
+ namespace cereal {
465
+ namespace detail {
466
+
467
+ template<typename SortPolicy,
468
+ typename DistanceType,
469
+ typename MatType,
470
+ template<typename TreeDistanceType,
471
+ typename TreeStatType,
472
+ typename TreeMatType> class TreeType,
473
+ template<typename RuleType> class DualTreeTraversalType,
474
+ template<typename RuleType> class SingleTreeTraversalType>
475
+ struct Version<mlpack::NeighborSearch<SortPolicy, DistanceType, MatType,
476
+ TreeType, DualTreeTraversalType, SingleTreeTraversalType>>
477
+ {
478
+ static std::uint32_t registerVersion()
479
+ {
480
+ ::cereal::detail::StaticObject<Versions>::getInstance().mapping.emplace(
481
+ std::type_index(typeid(mlpack::NeighborSearch<SortPolicy, DistanceType,
482
+ MatType, TreeType, DualTreeTraversalType,
483
+ SingleTreeTraversalType>)).hash_code(), 1);
484
+ return 1;
485
+ }
486
+
487
+ static inline const std::uint32_t version = registerVersion();
488
+
489
+ static void unused() { (void) version; }
490
+ }; /* end Version */
491
+
492
+ } // namespace detail
493
+ } // namespace cereal
494
+
375
495
  // Include implementation.
376
496
  #include "neighbor_search_impl.hpp"
377
497